---
library_name: gemma_torch
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
tags:
- pytorch
extra_gated_heading: Access CodeGemma on Hugging Face
extra_gated_prompt: To access CodeGemma on Hugging Face, you’re required to review
  and agree to Google’s usage license. To do this, please ensure you’re logged in
  to Hugging Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
---

# CodeGemma Model Card

> [!IMPORTANT]
>
> This repository corresponds to the CodeGemma 7B checkpoint for use with [Gemma PyTorch](https://github.com/google/gemma_pytorch). If you're looking for the `transformers` implementation or a more detailed model card, visit https://huggingface.co/google/codegemma-7b.

**Model Page**: [CodeGemma](https://ai.google.dev/gemma/docs/codegemma)

**Resources and Technical Documentation**:

* [Technical Report](https://goo.gle/codegemma)
* [Responsible Generative AI Toolkit](https://ai.google.dev/responsible)

**Terms of Use**: [Terms](https://www.kaggle.com/models/google/codegemma/license/consent/verify/huggingface?returnModelRepoId=google/codegemma-7b-pytorch)

**Authors**: Google

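# Sample Usage

The snippet below loads the CodeGemma 7B checkpoint with Gemma PyTorch on CPU and runs a fill-in-the-middle (FIM) completion. It expects this repository's `tokenizer.model` and `codegemma-7b.pt` files in a local `codegemma-7b-pytorch` directory.
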
```python
import contextlib
import os

import torch

from gemma.config import get_config_for_2b, get_config_for_7b
from gemma.model import GemmaForCausalLM

VARIANT = "7b"
MACHINE_TYPE = "cpu"  # e.g. "cuda" to run on a GPU
weights_dir = "codegemma-7b-pytorch"  # directory holding the downloaded checkpoint files


@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype, resetting it to float32 on exit."""
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Reset even if model construction or weight loading raises.
        torch.set_default_dtype(torch.float)


# Build the model configuration and point it at the SentencePiece tokenizer.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

# Create the model in the checkpoint's dtype, load the weights, and move it to the device.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    ckpt_path = os.path.join(weights_dir, f"codegemma-{VARIANT}.pt")
    model.load_weights(ckpt_path)
    model = model.to(device).eval()

# FIM prompt: the model generates the code that belongs between the prefix
# ("import ") and the suffix (the __main__ block).
FIM_PROMPT = """<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":
    sys.exit(0)<|fim_middle|>"""

result = model.generate(
    FIM_PROMPT,
    device=device,
    output_len=100,
)
print(result)
```
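
`FIM_PROMPT` hard-codes its prefix and suffix. When completing code at an arbitrary cursor position, it can help to build the prompt from the surrounding text and splice the generated middle back in. The sketch below does this with two hypothetical helpers, `make_fim_prompt` and `splice_completion`, which are not part of Gemma PyTorch. It reuses the sentinel tokens from the sample above and assumes, as a defensive measure only, that the decoded output might contain a sentinel token followed by extra text that should be discarded.

```python
# Hypothetical helpers (not part of gemma_pytorch) for CodeGemma FIM prompts.
FIM_PREFIX = "<|fim_prefix|>"
FIM_SUFFIX = "<|fim_suffix|>"
FIM_MIDDLE = "<|fim_middle|>"


def make_fim_prompt(prefix: str, suffix: str) -> str:
    """Build a prompt in the prefix, suffix, middle order used by FIM_PROMPT."""
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


def splice_completion(prefix: str, middle: str, suffix: str) -> str:
    """Insert the generated middle between prefix and suffix.

    Assumption: if the generated text contains a FIM sentinel token, everything
    from the first such token onward is treated as overrun and dropped.
    """
    for token in (FIM_PREFIX, FIM_SUFFIX, FIM_MIDDLE):
        cut = middle.find(token)
        if cut != -1:
            middle = middle[:cut]
    return prefix + middle + suffix


# Usage: rebuild the sample's FIM_PROMPT from its parts.
prefix = "import "
suffix = 'if __name__ == "__main__":\n    sys.exit(0)'
prompt = make_fim_prompt(prefix, suffix)

# `middle` stands in for the string returned by model.generate(prompt, ...).
middle = "sys\n\n"
print(splice_completion(prefix, middle, suffix))
```

Keeping prompt construction and splicing in plain string functions makes them easy to test without loading the 7B weights.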