kamran-r123 committed on
Commit
dbcfd8e
·
verified ·
1 Parent(s): 95b847a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +44 -7
main.py CHANGED
@@ -3,20 +3,57 @@ from pydantic import BaseModel
3
  from huggingface_hub import InferenceClient
4
  import uvicorn
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- app = FastAPI()
8
 
9
- client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
 
10
 
11
  class Item(BaseModel):
12
  prompt: str
13
  history: list
14
  system_prompt: str
15
- temperature: float = 0.7
16
- max_new_tokens: int = 512
17
- top_p: float = 0.15
18
- repetition_penalty: float = 1.0
19
- seed: int = 42
 
20
 
21
  def format_prompt(message, history):
22
  prompt = "<s>"
 
3
  from huggingface_hub import InferenceClient
4
  import uvicorn
5
 
6
+ # **************************************************
7
+ # import transformers
8
+ # import torch
9
+
10
+ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
11
+
12
+ # pipeline = transformers.pipeline(
13
+ # "text-generation",
14
+ # model=model_id,
15
+ # model_kwargs={"torch_dtype": torch.bfloat16},
16
+ # device_map="auto",
17
+ # )
18
+
19
+ def generate(item: Item):
20
+ messages = [
21
+ {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
22
+ {"role": "user", "content": "Who are you?"},
23
+ ]
24
+
25
+ terminators = [
26
+ pipeline.tokenizer.eos_token_id,
27
+ pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
28
+ ]
29
+
30
+ outputs = pipeline(
31
+ messages,
32
+ max_new_tokens=item.max_new_tokens,
33
+ eos_token_id=terminators,
34
+ do_sample=True,
35
+ temperature=item.temperature,
36
+ top_p=item.top_p,
37
+ )
38
+ return outputs[0]["generated_text"][-1]
39
+
40
+ # **************************************************
41
 
 
42
 
43
+
44
+
45
+ client = InferenceClient(model_id)
46
 
47
  class Item(BaseModel):
48
  prompt: str
49
  history: list
50
  system_prompt: str
51
+ temperature: float = 0.6
52
+ max_new_tokens: int = 1024
53
+ top_p: float = 0.95
54
+ seed : int = 42
55
+
56
+ app = FastAPI()
57
 
58
  def format_prompt(message, history):
59
  prompt = "<s>"