raduqus committed
Commit 27465f4 · verified · Parent: aa850c7

Update app.py

Files changed (1):
  1. app.py +56 -5
app.py CHANGED
@@ -1,7 +1,58 @@
- from vllm import serve
-
- # Define your model name
- model_name = "raduqus/reco_1b_4bit"
-
- # Serve the model
- serve(model_name=model_name, host="0.0.0.0", port=8000)
+ import torch
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Initialize the FastAPI app
+ app = FastAPI()
+
+ # Hugging Face model ID
+ model_id = "raduqus/reco_1b_16bit"
+
+ # Load the tokenizer and model once at startup
+ try:
+     print("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     print("Loading 16-bit model...")
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.float16,  # 16-bit floating-point precision
+         device_map="auto"  # automatically map layers to available devices
+     )
+     print("Model loaded successfully.")
+ except Exception as e:
+     raise RuntimeError(f"Failed to load the model: {e}")
+
+ # Input schema for task recommendations
+ class RecommendationRequest(BaseModel):
+     prompt: str
+     max_length: int = 100  # maximum number of tokens to generate
+     temperature: float = 0.7
+     top_p: float = 0.9
+
+ @app.post("/recommend")
+ async def recommend_task(request: RecommendationRequest):
+     """Generate task recommendations based on the input prompt."""
+     try:
+         # Tokenize the prompt and move it to the model's device
+         inputs = tokenizer(request.prompt, return_tensors="pt", truncation=True)
+         inputs = inputs.to(model.device)
+         outputs = model.generate(
+             **inputs,  # pass input_ids and attention_mask together
+             max_new_tokens=request.max_length,  # budget for generated tokens only
+             temperature=request.temperature,
+             top_p=request.top_p,
+             do_sample=True
+         )
+         # Decode the generated text (note: includes the prompt)
+         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return {"recommendation": generated_text}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error during generation: {e}")
+
+ @app.get("/")
+ async def root():
+     """Health check endpoint."""
+     return {"message": "Task recommender is running!"}
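With the server launched (for example via uvicorn app:app --host 0.0.0.0 --port 8000; uvicorn and the host/port are assumptions here, not part of the commit), the endpoints above can be exercised with a minimal Python client sketch like the following. The payload fields mirror the RecommendationRequest schema.

import requests  # assumed client-side dependency

BASE_URL = "http://localhost:8000"  # hypothetical host/port; adjust to your deployment

# Health check against the root endpoint
print(requests.get(f"{BASE_URL}/").json())
# expected: {"message": "Task recommender is running!"}

# Ask for a recommendation; fields mirror RecommendationRequest
payload = {
    "prompt": "Suggest a quick task for a rainy afternoon:",
    "max_length": 80,
    "temperature": 0.7,
    "top_p": 0.9,
}
resp = requests.post(f"{BASE_URL}/recommend", json=payload)
resp.raise_for_status()
print(resp.json()["recommendation"])

Because the server decodes the full output sequence, the returned text includes the original prompt; a caller that wants only the newly generated text can strip the prompt prefix client-side.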