parth parekh committed
Commit 4460c63
1 Parent(s): 22bc6d2

added chat endpoint

Files changed (1)
  1. main.py +45 -17
main.py CHANGED
@@ -6,7 +6,7 @@ from pydantic import BaseModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from dotenv import load_dotenv
 from accelerate import Accelerator
-
+from typing import List, Tuple
 # Load environment variables from a .env file (useful for local development)
 load_dotenv()
 
@@ -59,45 +59,43 @@ app = FastAPI(
     docs_url="/",  # URL for Swagger docs
     redoc_url="/doc"  # URL for ReDoc docs
 )
-# Set your Hugging Face token from environment variable
-HF_TOKEN = os.getenv("HF_TOKEN")
 
+HF_TOKEN = os.getenv("HF_TOKEN")
 MODEL = "meta-llama/Llama-3.2-1B-Instruct"
-
-# Auto-select CPU or GPU
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 
-# Set PyTorch to use all available CPU cores if running on CPU
 torch.set_num_threads(multiprocessing.cpu_count())
-
-# Initialize Accelerator for managing device allocation
 accelerator = Accelerator()
 
-# Load model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN, use_fast=True)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL,
     token=HF_TOKEN,
     torch_dtype=torch.float16,
-    device_map=device,
-    low_cpu_mem_usage=True,
-
+    device_map=device
 )
 
-# Prepare model for multi-device setup with accelerate
 model, tokenizer = accelerator.prepare(model, tokenizer)
-
-# Pydantic model for input
+# Pydantic models for request validation
 class PromptRequest(BaseModel):
     prompt: str
     max_new_tokens: int = 100
     temperature: float = 0.7
 
+class ChatRequest(BaseModel):
+    message: str
+    history: List[Tuple[str, str]] = []
+    max_new_tokens: int = 100
+    temperature: float = 0.7
+    system_prompt: str = "You are a helpful assistant."
+
+
+# Endpoints
 @app.post("/generate/")
 async def generate_text(request: PromptRequest):
     inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
-
+
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
@@ -106,6 +104,36 @@ async def generate_text(request: PromptRequest):
             do_sample=False,
             pad_token_id=tokenizer.eos_token_id
         )
-
+
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return {"response": response}
+
+@app.post("/chat/")
+async def chat(request: ChatRequest):
+    conversation = [
+        {"role": "system", "content": request.system_prompt}
+    ]
+    for human, assistant in request.history:
+        conversation.extend([
+            {"role": "user", "content": human},
+            {"role": "assistant", "content": assistant}
+        ])
+    conversation.append({"role": "user", "content": request.message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        outputs = model.generate(
+            input_ids,
+            max_new_tokens=request.max_new_tokens,
+            temperature=request.temperature,
+            do_sample=False,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Extract only the assistant's response
+    assistant_response = response.split("Assistant:")[-1].strip()
+
+    return {"response": assistant_response}
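For reference, the new /chat/ endpoint can be exercised with a small client like the sketch below. This is an illustrative example, not part of the commit: the base URL assumes the app is served locally (for instance via uvicorn main:app --port 8000), and the message, history, and system prompt values are made up. The field names mirror the ChatRequest model added above.

import requests

# Assumed local server address; adjust to wherever the FastAPI app is actually served.
BASE_URL = "http://localhost:8000"

payload = {
    "message": "What is the capital of France?",
    "history": [
        ["Hello!", "Hi there, how can I help you today?"]  # prior (user, assistant) turns
    ],
    "max_new_tokens": 100,
    "temperature": 0.7,
    "system_prompt": "You are a helpful assistant.",
}

# POST to the endpoint added in this commit and print the assistant's reply.
resp = requests.post(f"{BASE_URL}/chat/", json=payload)
resp.raise_for_status()
print(resp.json()["response"])

The handler responds with a JSON object whose "response" key holds the decoded text; note that model.generate is called with do_sample=False, so generation is greedy and the temperature value has no practical effect.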