raduqus committed · verified
Commit 7599399 · 1 Parent(s): 60ccf92

Update app.py

Files changed (1)
app.py +43 -20
app.py CHANGED
@@ -1,17 +1,27 @@
+import gradio as gr
+import numpy as np
+import random
+import torch
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-import random
-import numpy as np
-
-app = FastAPI()
+import uvicorn
+import threading
 
 # Configuration
 model_id = "raduqus/reco_1b_16bit"
 device = "cuda"
 MAX_SEED = np.iinfo(np.int32).max
 
+# Load model globally
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.float16,
+    variant="fp16",
+    use_safetensors=True
+).to(device)
+
 def infer(
     prompt,
     negative_prompt=None,
@@ -28,18 +38,6 @@ def infer(
     # Set random generator
     generator = torch.Generator().manual_seed(seed)
 
-    # Load model
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16,
-        variant="fp16",
-        use_safetensors=True
-    )
-
-    # Move to GPU
-    model = model.to(device)
-
     # Generate
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
     outputs = model.generate(
@@ -56,6 +54,9 @@ def infer(
 
     return generated_text
 
+# FastAPI for API endpoint
+api = FastAPI()
+
 class RecommendationRequest(BaseModel):
     prompt: str
     negative_prompt: str = None
@@ -65,7 +66,7 @@ class RecommendationRequest(BaseModel):
     temperature: float = 0.7
     top_p: float = 0.9
 
-@app.post("/recommend")
+@api.post("/recommend")
 async def recommend_task(request: RecommendationRequest):
     try:
         result = infer(
@@ -81,6 +82,28 @@ async def recommend_task(request: RecommendationRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+# Gradio Interface
+def gradio_infer(prompt):
+    return infer(prompt)
+
+# Create Gradio interface
+iface = gr.Interface(
+    fn=gradio_infer,
+    inputs=gr.Textbox(label="Prompt"),
+    outputs=gr.Textbox(label="Recommendation"),
+    title="Task Recommender",
+    description="Generate task recommendations"
+)
+
+# Function to start FastAPI server
+def start_api_server():
+    uvicorn.run(api, host="0.0.0.0", port=7860)
+
+# Main execution
 if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    # Start API server in a separate thread
+    api_thread = threading.Thread(target=start_api_server)
+    api_thread.start()
+
+    # Launch Gradio interface
+    iface.launch()
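The substance of this commit is twofold: tokenizer and model loading move out of infer() into module scope, so the weights load once at startup instead of on every request, and the single FastAPI app is split into a REST endpoint plus a Gradio UI. Once the server is up, /recommend accepts a JSON body matching RecommendationRequest. A minimal client sketch follows, assuming the API is reachable on localhost:7860 (the port start_api_server binds); the exact response shape depends on what recommend_task returns, which is elided in this diff:

import requests

# Call the /recommend endpoint defined above. The payload fields mirror
# the RecommendationRequest fields visible in the diff; negative_prompt
# and the other defaults are optional.
payload = {
    "prompt": "I finished my morning run. What should I do next?",
    "temperature": 0.7,
    "top_p": 0.9,
}
resp = requests.post("http://localhost:7860/recommend", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json())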
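One caveat with the new entry point: start_api_server binds uvicorn to port 7860, and iface.launch() also defaults to port 7860, so the two servers will likely contend for the same port (and Hugging Face Spaces exposes only one port). A common alternative is to mount the Gradio UI onto the FastAPI app with gr.mount_gradio_app and serve both from a single uvicorn process. The sketch below illustrates that layout under that assumption; it is not what this commit does, and gradio_infer here is a stand-in for the real infer():

import gradio as gr
import uvicorn
from fastapi import FastAPI

api = FastAPI()  # the app object carrying the /recommend route, as above

def gradio_infer(prompt):
    # Stand-in so the sketch is self-contained; replace with infer(prompt).
    return f"recommendation for: {prompt}"

iface = gr.Interface(
    fn=gradio_infer,
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Recommendation"),
)

# Mount the Gradio UI on the FastAPI app and serve both from one port.
app = gr.mount_gradio_app(api, iface, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)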