Update app.py
app.py CHANGED
@@ -1,17 +1,27 @@
+import gradio as gr
+import numpy as np
+import random
+import torch
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-import random
-import numpy as np
-
-app = FastAPI()
+import uvicorn
+import threading
 
 # Configuration
 model_id = "raduqus/reco_1b_16bit"
 device = "cuda"
 MAX_SEED = np.iinfo(np.int32).max
 
+# Load model globally
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.float16,
+    variant="fp16",
+    use_safetensors=True
+).to(device)
+
 def infer(
     prompt,
     negative_prompt=None,
@@ -28,18 +38,6 @@ def infer(
     # Set random generator
     generator = torch.Generator().manual_seed(seed)
 
-    # Load model
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16,
-        variant="fp16",
-        use_safetensors=True
-    )
-
-    # Move to GPU
-    model = model.to(device)
-
     # Generate
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
     outputs = model.generate(
@@ -56,6 +54,9 @@ def infer(
 
     return generated_text
 
+# FastAPI for API endpoint
+api = FastAPI()
+
 class RecommendationRequest(BaseModel):
     prompt: str
     negative_prompt: str = None
@@ -65,7 +66,7 @@ class RecommendationRequest(BaseModel):
     temperature: float = 0.7
     top_p: float = 0.9
 
-@app.post("/recommend")
+@api.post("/recommend")
 async def recommend_task(request: RecommendationRequest):
     try:
         result = infer(
@@ -81,6 +82,28 @@ async def recommend_task(request: RecommendationRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+# Gradio Interface
+def gradio_infer(prompt):
+    return infer(prompt)
+
+# Create Gradio interface
+iface = gr.Interface(
+    fn=gradio_infer,
+    inputs=gr.Textbox(label="Prompt"),
+    outputs=gr.Textbox(label="Recommendation"),
+    title="Task Recommender",
+    description="Generate task recommendations"
+)
+
+# Function to start FastAPI server
+def start_api_server():
+    uvicorn.run(api, host="0.0.0.0", port=7860)
+
+# Main execution
 if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    # Start API server in a separate thread
+    api_thread = threading.Thread(target=start_api_server)
+    api_thread.start()
+
+    # Launch Gradio interface
+    iface.launch()
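After this change the model is loaded once at import time and shared by the Gradio UI and the FastAPI endpoint, which run side by side. Because iface.launch() blocks the main thread, uvicorn has to run in its own thread for both servers to stay up; that is the role of threading.Thread(target=start_api_server) in the main block. Below is a minimal client sketch for the /recommend endpoint, assuming the API thread is reachable at localhost:7860 and that prompt is the only required field of RecommendationRequest; the prompt text and timeout are illustrative, not part of the commit.

import requests

# Hypothetical example call; host and port assume the uvicorn thread
# started by start_api_server() is listening on 0.0.0.0:7860.
payload = {
    "prompt": "Recommend my next task given that I just finished the weekly report.",
    # Optional fields keep their declared defaults (temperature=0.7, top_p=0.9).
}

response = requests.post("http://localhost:7860/recommend", json=payload, timeout=120)
response.raise_for_status()
print(response.json())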