raduqus committed on
Commit 60ccf92 · verified · 1 Parent(s): b517183

Update app.py

Files changed (1):
  app.py (+58 -32)
app.py CHANGED
@@ -2,59 +2,85 @@ from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
- import spaces

  app = FastAPI()

  model_id = "raduqus/reco_1b_16bit"

- # Initialize ZeroGPU
- spaces.gpu()
-
- try:
-     # Use spaces.gpu() decorator for initialization
-     @spaces.GPU
-     def load_model():
-         tokenizer = AutoTokenizer.from_pretrained(model_id)
-         model = AutoModelForCausalLM.from_pretrained(
-             model_id,
-             torch_dtype=torch.float16,
-             device_map="auto"
-         )
-         return tokenizer, model
-
-     tokenizer, model = load_model()
-     print("Model loaded successfully on ZeroGPU.")
- except Exception as e:
-     print(f"Model loading error: {e}")

  class RecommendationRequest(BaseModel):
      prompt: str
      max_length: int = 100
      temperature: float = 0.7
      top_p: float = 0.9

  @app.post("/recommend")
- @spaces.GPU  # Ensure GPU usage for inference
  async def recommend_task(request: RecommendationRequest):
      try:
-         inputs = tokenizer(request.prompt, return_tensors="pt")
-         outputs = model.generate(
-             inputs.input_ids.to('cuda'),
              max_length=request.max_length,
              temperature=request.temperature,
-             top_p=request.top_p,
-             do_sample=True
          )
-         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return {"recommendation": generated_text}
      except Exception as e:
          raise HTTPException(status_code=500, detail=str(e))

- @app.get("/")
- async def root():
-     return {"message": "Task recommender is running on ZeroGPU!"}
-
  if __name__ == "__main__":
      import uvicorn
      uvicorn.run(app, host="0.0.0.0", port=7860)
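
The removed code (the "-" side above) depended on the ZeroGPU `spaces` package: a bare module-level `spaces.gpu()` call (apparently a typo for the `spaces.GPU` decorator used a few lines later), a decorated `load_model()` that ran once at import time, a second `@spaces.GPU` decorator on the endpoint, and a `/` health-check route. The "+" side below drops the `spaces` dependency and the health-check route entirely, and moves all tokenizer/model handling into a new seeded `infer()` helper: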
 
  from pydantic import BaseModel
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
+ import random
+ import numpy as np

  app = FastAPI()

+ # Configuration
  model_id = "raduqus/reco_1b_16bit"
+ device = "cuda"
+ MAX_SEED = np.iinfo(np.int32).max

+ def infer(
+     prompt,
+     negative_prompt=None,
+     seed=0,
+     randomize_seed=True,
+     max_length=100,
+     temperature=0.7,
+     top_p=0.9
+ ):
+     # Seed handling
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     # Set random generator
+     generator = torch.Generator().manual_seed(seed)
+
+     # Load model
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.float16,
+         variant="fp16",
+         use_safetensors=True
+     )
+
+     # Move to GPU
+     model = model.to(device)
+
+     # Generate
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+     outputs = model.generate(
+         inputs.input_ids,
+         max_length=max_length,
+         temperature=temperature,
+         top_p=top_p,
+         do_sample=True,
+         generator=generator
+     )
+
+     # Decode
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     return generated_text
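
Two caveats on the "+" side as committed: `negative_prompt` is accepted but never used in generation, and `model.generate()` in transformers has no `generator` parameter (that argument belongs to diffusers pipelines), so recent transformers versions will typically reject it as an unused model kwarg. A minimal sketch of reproducible sampling without that kwarg, assuming the same model and tokenizer objects (`seeded_generate` is a hypothetical helper, not part of this commit):

    from transformers import set_seed

    def seeded_generate(model, tokenizer, prompt, seed,
                        device="cuda", max_length=100,
                        temperature=0.7, top_p=0.9):
        # set_seed seeds Python's, NumPy's, and torch's RNGs in one call,
        # which is the usual way to make do_sample=True reproducible.
        set_seed(seed)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = model.generate(
            inputs.input_ids,
            max_length=max_length,    # counts prompt tokens too;
            temperature=temperature,  # max_new_tokens is often what is meant
            top_p=top_p,
            do_sample=True,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)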
 
  class RecommendationRequest(BaseModel):
      prompt: str
+     negative_prompt: str = None
+     seed: int = 0
+     randomize_seed: bool = True
      max_length: int = 100
      temperature: float = 0.7
      top_p: float = 0.9
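
One small typing nit, assuming Pydantic v2: `negative_prompt: str = None` works when the field is omitted but rejects an explicit JSON `null`. Declaring it `Optional[str]` handles both cases (a sketch, not part of the commit):

    from typing import Optional

    from pydantic import BaseModel

    class RecommendationRequest(BaseModel):
        prompt: str
        negative_prompt: Optional[str] = None  # also accepts explicit null
        seed: int = 0
        randomize_seed: bool = True
        max_length: int = 100
        temperature: float = 0.7
        top_p: float = 0.9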
 
  @app.post("/recommend")
  async def recommend_task(request: RecommendationRequest):
      try:
+         result = infer(
+             prompt=request.prompt,
+             negative_prompt=request.negative_prompt,
+             seed=request.seed,
+             randomize_seed=request.randomize_seed,
              max_length=request.max_length,
              temperature=request.temperature,
+             top_p=request.top_p
          )
+         return {"recommendation": result}
      except Exception as e:
          raise HTTPException(status_code=500, detail=str(e))
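
For reference, a minimal client call against the updated endpoint; the URL and payload values here are illustrative, not from the commit:

    import requests

    resp = requests.post(
        "http://localhost:7860/recommend",
        json={
            "prompt": "Suggest a task for a rainy afternoon.",
            "seed": 42,
            "randomize_seed": False,  # use the fixed seed above
            "max_length": 120,
        },
    )
    print(resp.json())  # {"recommendation": "..."}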

  if __name__ == "__main__":
      import uvicorn
      uvicorn.run(app, host="0.0.0.0", port=7860)
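
Worth noting: as committed, `infer()` reloads the tokenizer and the fp16 weights from the Hub on every request, which is expensive. A common alternative, sketched below on the assumption that one copy of the model fits on the GPU, is to load once at import time and reuse it; `variant="fp16"` is omitted here because it requires `*.fp16.safetensors` files in the model repo, which this repo may not ship (hypothetical variant, not part of this commit):

    import random

    import numpy as np
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

    model_id = "raduqus/reco_1b_16bit"
    device = "cuda"
    MAX_SEED = np.iinfo(np.int32).max

    # Load once at import time, not per request.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        use_safetensors=True,
    ).to(device)
    model.eval()  # inference only; no gradients needed

    def infer(prompt, seed=0, randomize_seed=True,
              max_length=100, temperature=0.7, top_p=0.9):
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        set_seed(seed)  # reproducible sampling without a generator kwarg
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)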