paulilioaica commited on
Commit
7c23eef
·
verified ·
1 Parent(s): eabd9a9

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +32 -13
README.md CHANGED
@@ -50,22 +50,41 @@ experts:
50
 
51
  ```python
52
  !pip install -qU transformers bitsandbytes accelerate
53
-
54
- from transformers import AutoTokenizer
55
- import transformers
56
  import torch
57
 
58
- model = "paulilioaica/PhiMiX-2x2B"
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- tokenizer = AutoTokenizer.from_pretrained(model)
61
- pipeline = transformers.pipeline(
62
- "text-generation",
63
- model=model,
64
- model_kwargs={"torch_dtype": torch.float16, "load_in_4bit": True},
65
  )
66
 
67
- messages = [{"role": "user", "content": "Explain what a Mixture of Experts is in less than 100 words."}]
68
- prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
69
- outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
70
- print(outputs[0]["generated_text"])
 
 
 
 
 
 
 
 
 
 
71
  ```
 
50
 
51
  ```python
52
  !pip install -qU transformers bitsandbytes accelerate
53
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 
54
  import torch
55
 
56
+ model_name = "paulilioaica/PhiMiX-2x2B"
57
+
58
+ torch.set_default_device("cuda")
59
+
60
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
61
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
62
+
63
+ instruction = '''
64
+ def print_prime(n):
65
+ """
66
+ Print all primes between 1 and n
67
+ """
68
+ '''
69
+
70
 
71
+ tokenizer = AutoTokenizer.from_pretrained(
72
+ f"{model_name}",
73
+ trust_remote_code=True
 
 
74
  )
75
 
76
+ # Tokenize the input string
77
+ inputs = tokenizer(
78
+ instruction,
79
+ return_tensors="pt",
80
+ return_attention_mask=False
81
+ )
82
+
83
+ # Generate text using the model
84
+ outputs = model.generate(**inputs, max_length=200)
85
+
86
+ # Decode and print the output
87
+ text = tokenizer.batch_decode(outputs)[0]
88
+ print(text)
89
+
90
  ```