anakin87, hysts (HF staff), and Xenova (HF staff) committed on
Commit cc1bdc1 · verified · 0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Co-authored-by: hysts <[email protected]>
Co-authored-by: Xenova <[email protected]>

Files changed (5)
  1. .gitattributes +35 -0
  2. README.md +11 -0
  3. app.py +142 -0
  4. requirements.txt +2 -0
  5. style.css +11 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: Gemma 2 2B Neogenesis ITA
+ emoji: 💎💬🇮🇹
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 5.12.0
+ app_file: app.py
+ pinned: false
+ short_description: Chat with a small Italian model
+ ---
app.py ADDED
@@ -0,0 +1,142 @@
+ import os
+ from threading import Thread
+ from typing import Iterator
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+ DESCRIPTION = """\
+ # Gemma 2 2B Neogenesis ITA 💎 💬 🇮🇹
+
+ Fine-tuned version of google/gemma-2-2b-it that improves performance on Italian.
+ A small (2.6B parameters) but capable model with an 8k context length.
+
+ [🪪 **Model card**](https://huggingface.co/anakin87/gemma-2-2b-neogenesis-ita)
+ TODO: add Kaggle link
+ """
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ model_id = "anakin87/gemma-2-2b-neogenesis-ita"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ model.config.sliding_window = 4096
+ model.eval()
+
+
+ @spaces.GPU
+ def generate(
+     message: str,
+     chat_history: list[dict],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     conversation = chat_history.copy()
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Ciao! Come stai?"],
+         ["Pro e contro di una relazione a lungo termine. Elenco puntato con max 3 pro e 3 contro sintetici."],
+         ["Quante ore impiega un uomo per mangiare un elicottero?"],
+         ["Come si apre un file JSON in Python?"],
+         ["Fammi un elenco puntato dei pro e contro di vivere in Italia. Massimo 2 pro e 2 contro."],
+         ["Inventa una breve storia con animali sul valore dell'amicizia."],
+         ["Scrivi un articolo di 100 parole sui 'Benefici dell'open-source nella ricerca sull'intelligenza artificiale'"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+     ],
+     cache_examples=False,
+     type="messages",
+ )
+
+ fonts = {"font": [gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
+          "font_mono": [gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", "monospace"]}
+
+ with gr.Blocks(css_paths="style.css", fill_height=True, theme=gr.themes.Soft(**fonts)) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+     chat_interface.render()
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
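
Outside the Gradio UI, the same fine-tuned model can be queried directly with transformers. Below is a minimal non-streaming sketch, illustrative only and not part of this commit; it assumes the same model id and chat template that app.py uses.

```python
# Minimal sketch: query anakin87/gemma-2-2b-neogenesis-ita directly,
# without the Gradio app or token streaming. Not part of the commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "anakin87/gemma-2-2b-neogenesis-ita"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

messages = [{"role": "user", "content": "Ciao! Come stai?"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

with torch.no_grad():
    output = model.generate(
        input_ids, max_new_tokens=256, do_sample=True, temperature=0.6, top_p=0.9, top_k=50
    )

# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True))
```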
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ accelerate
+ transformers>=4.44.2
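
Note that `gradio` and `spaces` are not listed here because the Space runtime provides them (gradio at the `sdk_version` pinned in README.md). When running app.py elsewhere, they must be installed separately. A hypothetical pre-launch check, not part of the commit, to confirm the transformers floor locally:

```python
# Hypothetical local check (not part of the commit): verify that the installed
# transformers version satisfies the requirements.txt floor.
# Assumes a standard x.y.z release string.
from importlib.metadata import version

installed = tuple(int(part) for part in version("transformers").split(".")[:3])
assert installed >= (4, 44, 2), f"transformers>=4.44.2 required, found {version('transformers')}"
```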
style.css ADDED
@@ -0,0 +1,11 @@
+ h1 {
+   text-align: center;
+   display: block;
+ }
+
+ #duplicate-button {
+   margin: auto;
+   color: #fff;
+   background: #1565c0;
+   border-radius: 100vh;
+ }