kousw commited on
Commit
74ee384
ยท
verified ยท
1 Parent(s): b37a03b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +38 -1
README.md CHANGED
@@ -1,3 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
- ---
 
 
 
1
+ ![image](x1.png)
2
+
3
+ This model employs the technique described in ["Chat Vector: A Simple Approach to Equip LLMs with Instruction Following and Model Alignment in New Languages"](https://arxiv.org/abs/2310.04799).
4
+
5
+ It is based on [stablelm-gamma-7b](https://huggingface.co/stabilityai/japanese-stablelm-base-gamma-7b), a model that has not undergone instruction tuning, which was pre-trained using [mistral-7b-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1).
6
+
7
+ To extract chat vectors, mistral-7b-v0.1 was "subtracted" from [mistral-7b-instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2).
8
+
9
+ By applying these extracted chat vectors to the non-instruction-tuned model stablelm-gamma-7b, an effect equivalent to instruction tuning is achieved.
10
+
11
+
12
+ ```python
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+
15
+ device = "cuda" # the device to load the model onto
16
+
17
+ model = AutoModelForCausalLM.from_pretrained("kousw/stablelm-gamma-7b-chatvector")
18
+ tokenizer = AutoTokenizer.from_pretrained("kousw/stablelm-gamma-7b-chatvector")
19
+
20
+ messages = [
21
+ {"role": "user", "content": "ไธŽใˆใ‚‰ใ‚ŒใŸใ“ใจใ‚ใ–ใฎๆ„ๅ‘ณใ‚’ๅฐๅญฆ็”Ÿใงใ‚‚ๅˆ†ใ‹ใ‚‹ใ‚ˆใ†ใซๆ•™ใˆใฆใใ ใ•ใ„ใ€‚"},
22
+ {"role": "assistant", "content": "ใฏใ„ใ€ใฉใ‚“ใชใ“ใจใ‚ใ–ใงใ‚‚ใ‚ใ‹ใ‚Šใ‚„ใ™ใ็ญ”ใˆใพใ™"},
23
+ {"role": "user", "content": "ๆƒ…ใ‘ใฏไบบใฎใŸใ‚ใชใ‚‰ใš"}
24
+ ]
25
+
26
+ encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
27
+
28
+ model_inputs = encodeds.to(device)
29
+ model.to(device)
30
+
31
+ generated_ids = model.generate(model_inputs, max_new_tokens=256, do_sample=True)
32
+ decoded = tokenizer.batch_decode(generated_ids)
33
+ print(decoded[0])
34
+ ```
35
+
36
  ---
37
  license: apache-2.0
38
+ language:
39
+ - ja
40
+ ---