Hjgugugjhuhjggg committed · Commit 14bbbee · verified · 1 Parent(s): d27b777

Update app.py

Files changed (1)
  1. app.py +31 -42
app.py CHANGED
@@ -2,7 +2,7 @@ import os
  import torch
  from fastapi import FastAPI, HTTPException
  from fastapi.responses import StreamingResponse
- from pydantic import BaseModel, field_validator
+ from pydantic import BaseModel
  from transformers import (
      AutoConfig,
      pipeline,
@@ -11,25 +11,24 @@ from transformers import (
      GenerationConfig,
      StoppingCriteriaList
  )
- import boto3
- import uvicorn
  import asyncio
  from io import BytesIO
- from transformers import pipeline

+ # Global dictionary for storing the tokens
+ token_dict = {}
+
+ # Setup for accessing models on Hugging Face or S3
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
  AWS_REGION = os.getenv("AWS_REGION")
  S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
  HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

- s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
-
  app = FastAPI()

  class GenerateRequest(BaseModel):
      model_name: str
-     input_text: str = ""
+     input_text: str
      task_type: str
      temperature: float = 1.0
      max_new_tokens: int = 200
@@ -42,19 +41,6 @@ class GenerateRequest(BaseModel):
      chunk_delay: float = 0.0
      stop_sequences: list[str] = []

-     @field_validator("model_name")
-     def model_name_cannot_be_empty(cls, v):
-         if not v:
-             raise ValueError("model_name cannot be empty.")
-         return v
-
-     @field_validator("task_type")
-     def task_type_must_be_valid(cls, v):
-         valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
-         if v not in valid_types:
-             raise ValueError(f"task_type must be one of: {valid_types}")
-         return v
-
  class S3ModelLoader:
      def __init__(self, bucket_name, s3_client):
          self.bucket_name = bucket_name
@@ -62,32 +48,29 @@ class S3ModelLoader:

      def _get_s3_uri(self, model_name):
          return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
-
+
      async def load_model_and_tokenizer(self, model_name):
+         if model_name in token_dict:
+             return token_dict[model_name]
+
          s3_uri = self._get_s3_uri(model_name)
          try:
-             config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
-             model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=True)
-             tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
+             model = AutoModelForCausalLM.from_pretrained(s3_uri, local_files_only=True)
+             tokenizer = AutoTokenizer.from_pretrained(s3_uri, local_files_only=True)

-             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
+             if tokenizer.eos_token_id is None:
+                 tokenizer.eos_token_id = tokenizer.pad_token_id

-             return model, tokenizer
-         except EnvironmentError:
-             try:
-                 config = AutoConfig.from_pretrained(model_name)
-                 tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-                 model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
+             token_dict[model_name] = {
+                 "model": model,
+                 "tokenizer": tokenizer,
+                 "pad_token_id": tokenizer.pad_token_id,
+                 "eos_token_id": tokenizer.eos_token_id
+             }

-                 if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                     tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
-
-                 model.save_pretrained(s3_uri)
-                 tokenizer.save_pretrained(s3_uri)
-                 return model, tokenizer
-             except Exception as e:
-                 raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
+             return token_dict[model_name]
+         except Exception as e:
+             raise HTTPException(status_code=500, detail=f"Error loading model: {e}")

  model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)

@@ -96,7 +79,6 @@ async def generate(request: GenerateRequest):
      try:
          model_name = request.model_name
          input_text = request.input_text
-         task_type = request.task_type
          temperature = request.temperature
          max_new_tokens = request.max_new_tokens
          stream = request.stream
@@ -108,7 +90,13 @@ async def generate(request: GenerateRequest):
          chunk_delay = request.chunk_delay
          stop_sequences = request.stop_sequences

-         model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
+         # Load the model and tokenizer from S3
+         model_data = await model_loader.load_model_and_tokenizer(model_name)
+         model = model_data["model"]
+         tokenizer = model_data["tokenizer"]
+         pad_token_id = model_data["pad_token_id"]
+         eos_token_id = model_data["eos_token_id"]
+
          device = "cuda" if torch.cuda.is_available() else "cpu"
          model.to(device)

@@ -239,4 +227,5 @@ async def generate_video(request: GenerateRequest):
          raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

  if __name__ == "__main__":
+     import uvicorn
      uvicorn.run(app, host="0.0.0.0", port=7860)
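
Note for reviewers: the commit removes import boto3 and the s3_client = boto3.client(...) assignment, yet the module still evaluates model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client) at import time, so app.py should now fail on startup with a NameError. A minimal follow-up sketch that restores the client exactly as it was defined before this commit:

import boto3

# Recreate the S3 client that S3ModelLoader is still handed below; the
# credentials come from the environment variables read at the top of app.py.
s3_client = boto3.client(
    "s3",
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION,
)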
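
For a quick sanity check of the new token_dict cache, a hypothetical smoke test against a locally running copy of app.py (it assumes the generate handler is mounted at POST /generate and that the requests package is installed; neither is shown in this diff):

import requests

payload = {
    "model_name": "gpt2",         # example model id; any causal LM works
    "input_text": "Hello",
    "task_type": "text-to-text",  # still required, but no longer validated after this commit
}

# The first call loads the model from the S3 bucket (and now fails outright if
# it is not cached there, since the Hugging Face fallback was removed); the
# second call should be answered from token_dict without reloading.
for attempt in range(2):
    resp = requests.post("http://localhost:7860/generate", json=payload, timeout=300)
    print(attempt, resp.status_code)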