Hjgugugjhuhjggg committed on
Commit 6a7d8ad · verified · 1 Parent(s): 757421a

Update app.py

Files changed (1)
  1. app.py +173 -222
app.py CHANGED
@@ -1,8 +1,8 @@
import os
import torch
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends, BackgroundTasks, Request, Query, APIRouter, Path, Body, status, Response, Header
- from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse, PlainTextResponse, RedirectResponse
- from pydantic import BaseModel, validator, Field, root_validator, EmailStr, constr, ValidationError
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
@@ -19,90 +19,55 @@ from transformers import (
    AutoModelForTokenClassification,
    AutoModelForMaskedLM,
    AutoModelForObjectDetection,
-     AutoModelForSeq2SeqLM
)
from io import BytesIO
import boto3
- from botocore.exceptions import NoCredentialsError, ClientError
from huggingface_hub import snapshot_download
- import asyncio
import tempfile
import hashlib
from PIL import Image
- import base64
from typing import Optional, List, Union, Dict, Any
import uuid
- import subprocess
- import json
- from starlette.middleware.cors import CORSMiddleware
- import numpy as np
- from typing import Dict, Any
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.gzip import GZipMiddleware
- from transformers import AutoImageProcessor, pipeline
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
- from fastapi.security.api_key import APIKeyCookie
- from fastapi import Depends, Security, status, APIRouter, UploadFile, File, Request
- from fastapi.security import APIKeyHeader, OAuth2PasswordRequestForm
- from passlib.context import CryptContext
- from jose import JWTError, jwt
- from datetime import datetime, timedelta
- from starlette.requests import Request
- import logging
- from pydantic import EmailStr, constr, ValidationError
- from database import insert_user, get_user, delete_user, update_user, create_db_and_table
- from starlette.middleware import Middleware
- from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
- from starlette.types import ASGIApp
- import uvicorn
- from starlette.responses import StreamingResponse
- import logging
- from pydantic import EmailStr, constr, ValidationError
- from database import insert_user, get_user, delete_user, update_user, create_db_and_table, get_all_users
- from starlette.middleware import Middleware
- from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
- from starlette.types import ASGIApp
- import uvicorn
- from starlette.responses import StreamingResponse
- import logging
- from fastapi.exceptions import RequestValidationError
- from fastapi import Request, status, Depends
- from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
- from jose import JWTError, jwt
- from passlib.context import CryptContext
- from datetime import datetime, timedelta
- from pydantic import BaseModel, field_validator, model_validator, Field, EmailStr, constr, ValidationError
- from typing import Optional, List, Union

- #setting up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
logger = logging.getLogger(__name__)

- #JWT Settings
SECRET_KEY = os.getenv("SECRET_KEY")
if not SECRET_KEY:
    raise ValueError("SECRET_KEY must be set.")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

- #Password Hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

- #Database connection - replace with your database setup
- #Example using SQLite
- import sqlite3
conn = sqlite3.connect('users.db')
cursor = conn.cursor()

- #OAuth2
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-
- #API Key
API_KEY = os.getenv("API_KEY")
api_key_header = APIKeyHeader(name="X-API-Key")

- #Configuration
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")
@@ -115,15 +80,8 @@ TEMPLATES = Jinja2Templates(directory="templates")
app = FastAPI()
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.add_middleware(GZipMiddleware)

- origins = ["*"]
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=origins,
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )

class User(BaseModel):
    username: constr(min_length=3, max_length=50)
@@ -131,9 +89,9 @@ class User(BaseModel):
    password: constr(min_length=8)

class GenerateRequest(BaseModel):
-     model_name: str
-     input_text: Optional[str] = Field(None, description="Input text for generation.")
-     task_type: str = Field(..., description="Type of generation task (text, image, audio, video, classification, translation, question-answering, speech-to-text, text-to-speech, image-segmentation, feature-extraction, token-classification, fill-mask, image-inpainting, image-super-resolution, object-detection, image-captioning, audio-transcription, summarization).")
    temperature: float = 1.0
    max_new_tokens: int = 200
    stream: bool = True
@@ -149,10 +107,10 @@ class GenerateRequest(BaseModel):
    target_language: Optional[str] = None
    context: Optional[str] = None
    audio_file: Optional[UploadFile] = None
-     raw_input: Optional[Union[str, bytes]] = None # for feature extraction
-     masked_text: Optional[str] = None # for fill-mask
-     mask_image: Optional[UploadFile] = None # for image inpainting
-     low_res_image: Optional[UploadFile] = None # for image super-resolution

    @field_validator('task_type')
    def validate_task_type(cls, value):
@@ -182,6 +140,7 @@ class GenerateRequest(BaseModel):
            raise ValueError("low_res_image is required for image super-resolution.")
        return values

class S3ModelLoader:
    def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
        self.bucket_name = bucket_name
@@ -286,40 +245,43 @@ class S3ModelLoader:
        raise ValueError("Unsupported task type")

async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
-     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
-     input_length = encoded_input["input_ids"].shape[1]
-     max_length = model.config.max_length
-     remaining_tokens = max_length - input_length
-     if remaining_tokens <= 0:
-         yield ""
-     generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
-     def stop_criteria(input_ids, scores):
-         decoded_output = tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
-         return decoded_output in stop_sequences
-     stopping_criteria = StoppingCriteriaList([stop_criteria])
-     outputs = model.generate(
-         **encoded_input,
-         do_sample=generation_config.do_sample,
-         max_new_tokens=generation_config.max_new_tokens,
-         temperature=generation_config.temperature,
-         top_p=generation_config.top_p,
-         top_k=generation_config.top_k,
-         repetition_penalty=generation_config.repetition_penalty,
-         num_return_sequences=generation_config.num_return_sequences,
-         stopping_criteria=stopping_criteria,
-         output_scores=True,
-         return_dict_in_generate=True
-     )
-     for output in outputs.sequences:
-         for token_id in output:
-             token = tokenizer.decode(token_id, skip_special_tokens=True)
-             yield token


model_loader = S3ModelLoader(S3_BUCKET_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)

def get_model_data(request: GenerateRequest):
-     return model_loader.load_model_and_tokenizer(request.model_name, request.task_type)

async def verify_api_key(api_key: str = Depends(api_key_header)):
    if api_key != API_KEY:
@@ -327,7 +289,7 @@ async def verify_api_key(api_key: str = Depends(api_key_header)):


@app.post("/generate", dependencies=[Depends(verify_api_key)])
- async def generate(request: GenerateRequest, background_tasks: BackgroundTasks, model_data = Depends(get_model_data)):
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if request.task_type == "text":
@@ -342,31 +304,31 @@ async def generate(request: GenerateRequest, background_tasks: BackgroundTasks,
                do_sample=request.do_sample,
                num_return_sequences=request.num_return_sequences,
            )
-             async def stream_with_tokens():
-                 async for token in stream_text(model, tokenizer, request.input_text, generation_config, request.stop_sequences, device, request.chunk_delay):
-                     yield f"Token: {token}\n"
-             return StreamingResponse(stream_with_tokens(), media_type="text/plain")
        elif request.task_type in ["image", "audio", "video"]:
-             pipeline = model_data["pipeline"]
-             result = pipeline(request.input_text)
-             if request.task_type == "image":
-                 image = result[0]
-                 img_byte_arr = BytesIO()
-                 image.save(img_byte_arr, format="PNG")
-                 img_byte_arr.seek(0)
-                 return StreamingResponse(img_byte_arr, media_type="image/png")
-             elif request.task_type == "audio":
-                 audio = result[0]
-                 audio_byte_arr = BytesIO()
-                 audio.save(audio_byte_arr, format="wav")
-                 audio_byte_arr.seek(0)
-                 return StreamingResponse(audio_byte_arr, media_type="audio/wav")
-             elif request.task_type == "video":
-                 video = result[0]
-                 video_byte_arr = BytesIO()
-                 video.save(video_byte_arr, format="mp4")
-                 video_byte_arr.seek(0)
-                 return StreamingResponse(video_byte_arr, media_type="video/mp4")
        elif request.task_type == "classification":
            if request.image_file is None:
                raise HTTPException(status_code=400, detail="Image file is required for classification.")
@@ -406,38 +368,39 @@ async def generate(request: GenerateRequest, background_tasks: BackgroundTasks,
            if request.audio_file is None:
                raise HTTPException(status_code=400, detail="Audio file is required for speech-to-text.")
            contents = await request.audio_file.read()
-             pipeline = model_data["pipeline"]
            try:
-                 transcription = pipeline(contents, sampling_rate=16000)[0]["text"] # Assuming 16kHz sampling rate
                return JSONResponse({"transcription": transcription})
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Error during speech-to-text: {str(e)}")
-
        elif request.task_type == "text-to-speech":
            if not request.input_text:
                raise HTTPException(status_code=400, detail="Input text is required for text-to-speech.")
-             pipeline = model_data["pipeline"]
            try:
-                 audio = pipeline(request.input_text)[0]
                file_path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}.wav")
                audio.save(file_path)
                background_tasks.add_task(os.remove, file_path)
                return FileResponse(file_path, media_type="audio/wav")
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Error during text-to-speech: {str(e)}")
-
        elif request.task_type == "image-segmentation":
            if request.image_file is None:
                raise HTTPException(status_code=400, detail="Image file is required for image segmentation.")
            contents = await request.image_file.read()
            image = Image.open(BytesIO(contents)).convert("RGB")
-             pipeline = model_data["pipeline"]
-             result = pipeline(image)
-             mask = result[0]['mask']
-             mask_byte_arr = BytesIO()
-             mask.save(mask_byte_arr, format="PNG")
-             mask_byte_arr.seek(0)
-             return StreamingResponse(mask_byte_arr, media_type="image/png")
        elif request.task_type == "feature-extraction":
            if request.raw_input is None:
                raise HTTPException(status_code=400, detail="raw_input is required for feature extraction.")
@@ -450,7 +413,7 @@ async def generate(request: GenerateRequest, background_tasks: BackgroundTasks,
                    inputs = feature_extractor(images=image, return_tensors="pt")
                else:
                    raise ValueError("Unsupported raw_input type.")
-                 features = inputs.pixel_values # Adjust according to your feature extractor
                return JSONResponse({"features": features.tolist()})
            except Exception as fe:
                raise HTTPException(status_code=400, detail=f"Error during feature extraction: {fe}")
@@ -484,70 +447,81 @@ async def generate(request: GenerateRequest, background_tasks: BackgroundTasks,
            image_contents = await request.image_file.read()
            mask_contents = await request.mask_image.read()
            image = Image.open(BytesIO(image_contents)).convert("RGB")
-             mask = Image.open(BytesIO(mask_contents)).convert("L") # Assuming mask is grayscale
-             pipeline = model_data["pipeline"]
-             result = pipeline(image, mask)
-             inpainted_image = result[0]
-             img_byte_arr = BytesIO()
-             inpainted_image.save(img_byte_arr, format="PNG")
-             img_byte_arr.seek(0)
-             return StreamingResponse(img_byte_arr, media_type="image/png")
        elif request.task_type == "image-super-resolution":
            if request.low_res_image is None:
                raise HTTPException(status_code=400, detail="low_res_image is required for image super-resolution.")
            contents = await request.low_res_image.read()
            image = Image.open(BytesIO(contents)).convert("RGB")
-             pipeline = model_data["pipeline"]
-             result = pipeline(image)
-             upscaled_image = result[0]
-             img_byte_arr = BytesIO()
-             upscaled_image.save(img_byte_arr, format="PNG")
-             img_byte_arr.seek(0)
-             return StreamingResponse(img_byte_arr, media_type="image/png")
        elif request.task_type == "object-detection":
            if request.image_file is None:
                raise HTTPException(status_code=400, detail="Image file is required for object detection.")
            contents = await request.image_file.read()
            image = Image.open(BytesIO(contents)).convert("RGB")
-             pipeline = model_data["pipeline"]
            image_processor = model_data["image_processor"]
            inputs = image_processor(images=image, return_tensors="pt")
            with torch.no_grad():
-                 outputs = pipeline(image)
-                 detections = outputs
-             return JSONResponse({"detections": detections})
        elif request.task_type == "image-captioning":
            if request.image_file is None:
                raise HTTPException(status_code=400, detail="Image file is required for image captioning.")
            contents = await request.image_file.read()
            image = Image.open(BytesIO(contents)).convert("RGB")
-             pipeline = model_data["pipeline"]
-             caption = pipeline(image)[0]['generated_text']
-             return JSONResponse({"caption": caption})
        elif request.task_type == "audio-transcription":
            if request.audio_file is None:
                raise HTTPException(status_code=400, detail="Audio file is required for audio transcription.")
            try:
-                 contents = await request.audio_file.read()
-                 pipeline = model_data["pipeline"]
-                 try:
-                     transcription = pipeline(contents, sampling_rate=16000)[0]["text"] # Assuming 16kHz sampling rate
-                     return JSONResponse({"transcription": transcription})
-                 except Exception as e:
-                     raise HTTPException(status_code=500, detail=f"Error during audio transcription (pipeline): {str(e)}")
            except Exception as e:
-                 raise HTTPException(status_code=500, detail=f"Error during audio transcription (file read): {str(e)}")
        elif request.task_type == "summarization":
            if request.input_text is None:
                raise HTTPException(status_code=400, detail="Input text is required for summarization.")
            model = model_data["model"].to(device)
            tokenizer = model_data["tokenizer"]
-             inputs = tokenizer(request.input_text, return_tensors="pt", truncation=True, max_length=512) # added max_length for summarization
            with torch.no_grad():
-                 outputs = model.generate(**inputs)
-                 summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
-             return JSONResponse({"summary": summary})
-
        else:
            raise HTTPException(status_code=500, detail=f"Unsupported task type")
    except Exception as e:
@@ -563,25 +537,24 @@ async def root(request: Request):
async def health_check():
    return {"status": "healthy"}

- # Authentication Endpoints

@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
-         raise HTTPException(
-             status_code=status.HTTP_401_UNAUTHORIZED,
-             detail="Incorrect username or password",
-             headers={"WWW-Authenticate": "Bearer"},
-         )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires)
    return {"access_token": access_token, "token_type": "bearer"}

def authenticate_user(username: str, password: str):
-     user = get_user(username)
-     if user and pwd_context.verify(password, user.hashed_password):
-         return {"username": user.username}
    return None

def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None):
@@ -594,30 +567,22 @@ def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None):
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

- class Token(BaseModel):
-     access_token: str
-     token_type: str
-

@app.get("/users/me")
async def read_users_me(current_user: str = Depends(get_current_user)):
    return {"username": current_user}

async def get_current_user(token: str = Depends(oauth2_scheme)):
-     credentials_exception = HTTPException(
-         status_code=status.HTTP_401_UNAUTHORIZED,
-         detail="Could not validate credentials",
-         headers={"WWW-Authenticate": "Bearer"},
-     )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
-         token_data = {"username": username, "token": token}
    except JWTError:
        raise credentials_exception
-     user = get_user(username)
    if user is None:
        raise credentials_exception
    return username
@@ -627,12 +592,11 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
async def create_user(user: User):
    try:
        hashed_password = pwd_context.hash(user.password)
-         new_user = {"username": user.username, "email": user.email, "hashed_password": hashed_password}
-         inserted_user = insert_user(new_user)
-         if inserted_user:
-             return User(**inserted_user)
-         else:
-             raise HTTPException(status_code=500, detail="Failed to create user.")
    except Exception as e:
        logger.error(f"Error creating user: {e}")
        raise HTTPException(status_code=500, detail=f"Error creating user: {e}")
@@ -642,27 +606,20 @@ async def create_user(user: User):
async def update_user_data(username: str, user: User):
    try:
        hashed_password = pwd_context.hash(user.password)
-         updated_user_data = {"email": user.email, "hashed_password": hashed_password}
-         updated_user = update_user(username, updated_user_data)
-         if updated_user:
-             return User(**updated_user)
-         else:
-             raise HTTPException(status_code=404, detail="User not found")
-
    except Exception as e:
        logger.error(f"Error updating user: {e}")
        raise HTTPException(status_code=500, detail="Error updating user.")


-
@app.delete("/users/{username}", dependencies=[Depends(get_current_user)])
async def delete_user_account(username: str):
    try:
-         deleted_user = delete_user(username)
-         if deleted_user:
-             return JSONResponse({"message": "User deleted successfully."}, status_code=200)
-         else:
-             raise HTTPException(status_code=404, detail="User not found")
    except Exception as e:
        logger.error(f"Error deleting user: {e}")
        raise HTTPException(status_code=500, detail="Error deleting user.")
@@ -670,20 +627,14 @@ async def delete_user_account(username: str):

@app.get("/users", dependencies=[Depends(get_current_user)])
async def get_all_users_route():
-     return get_all_users()
-


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
-     return JSONResponse(
-         status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-         content=json.dumps({"detail": exc.errors(), "body": exc.body}),
-     )
-

if __name__ == "__main__":
-
-     create_db_and_table() # Initialize database on startup
-
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
import os
import torch
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, BackgroundTasks, Request, status
+ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse
+ from fastapi.exceptions import RequestValidationError  # used by the validation handler below
+ from pydantic import BaseModel, validator, Field, root_validator, EmailStr, constr, field_validator, model_validator
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,

    AutoModelForTokenClassification,
    AutoModelForMaskedLM,
    AutoModelForObjectDetection,
+     AutoImageProcessor,
)
from io import BytesIO
import boto3
+ from botocore.exceptions import ClientError
from huggingface_hub import snapshot_download

import tempfile
import hashlib
from PIL import Image

from typing import Optional, List, Union, Dict, Any
import uuid
+ import json  # used by validation_exception_handler below
+ import logging
+ import sqlite3
+ import uvicorn  # used in the __main__ block below
+ from passlib.context import CryptContext
+ from jose import JWTError, jwt
+ from datetime import datetime, timedelta
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.gzip import GZipMiddleware

from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
+ from starlette.middleware.cors import CORSMiddleware
+

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
logger = logging.getLogger(__name__)

SECRET_KEY = os.getenv("SECRET_KEY")
if not SECRET_KEY:
    raise ValueError("SECRET_KEY must be set.")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

conn = sqlite3.connect('users.db')
cursor = conn.cursor()
+ cursor.execute('''
+     CREATE TABLE IF NOT EXISTS users (
+         username TEXT PRIMARY KEY,
+         email TEXT UNIQUE,
+         hashed_password TEXT
+     )
+ ''')
+ conn.commit()

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

API_KEY = os.getenv("API_KEY")
api_key_header = APIKeyHeader(name="X-API-Key")

AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")

app = FastAPI()
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.add_middleware(GZipMiddleware)
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])

class User(BaseModel):
    username: constr(min_length=3, max_length=50)

    password: constr(min_length=8)

class GenerateRequest(BaseModel):
+     model_id: str
+     input_text: Optional[str] = Field(None)
+     task_type: str = Field(...)
    temperature: float = 1.0
    max_new_tokens: int = 200
    stream: bool = True

    target_language: Optional[str] = None
    context: Optional[str] = None
    audio_file: Optional[UploadFile] = None
+     raw_input: Optional[Union[str, bytes]] = None
+     masked_text: Optional[str] = None
+     mask_image: Optional[UploadFile] = None
+     low_res_image: Optional[UploadFile] = None

    @field_validator('task_type')
    def validate_task_type(cls, value):

            raise ValueError("low_res_image is required for image super-resolution.")
        return values

+
class S3ModelLoader:
    def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
        self.bucket_name = bucket_name

        raise ValueError("Unsupported task type")

async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
+     try:
+         encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
+         input_length = encoded_input["input_ids"].shape[1]
+         max_length = model.config.max_length
+         remaining_tokens = max_length - input_length
+         if remaining_tokens <= 0:
+             yield ""
+             return  # nothing left in the context window; stop instead of requesting a negative token budget
+         generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
+         def stop_criteria(input_ids, scores):
+             decoded_output = tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
+             return decoded_output in stop_sequences
+         stopping_criteria = StoppingCriteriaList([stop_criteria])
+         outputs = model.generate(
+             **encoded_input,
+             do_sample=generation_config.do_sample,
+             max_new_tokens=generation_config.max_new_tokens,
+             temperature=generation_config.temperature,
+             top_p=generation_config.top_p,
+             top_k=generation_config.top_k,
+             repetition_penalty=generation_config.repetition_penalty,
+             num_return_sequences=generation_config.num_return_sequences,
+             stopping_criteria=stopping_criteria,
+             output_scores=True,
+             return_dict_in_generate=True
+         )
+         for output in outputs.sequences:
+             for token_id in output:
+                 token = tokenizer.decode(token_id, skip_special_tokens=True)
+                 yield token
+     except Exception as e:
+         yield f"Error during text generation: {e}"

model_loader = S3ModelLoader(S3_BUCKET_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)

def get_model_data(request: GenerateRequest):
+     return model_loader.load_model_and_tokenizer(request.model_id, request.task_type)

async def verify_api_key(api_key: str = Depends(api_key_header)):
    if api_key != API_KEY:


@app.post("/generate", dependencies=[Depends(verify_api_key)])
+ async def generate(request: GenerateRequest, background_tasks: BackgroundTasks, model_data=Depends(get_model_data)):
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if request.task_type == "text":

                do_sample=request.do_sample,
                num_return_sequences=request.num_return_sequences,
            )
+             return StreamingResponse(stream_text(model, tokenizer, request.input_text, generation_config, request.stop_sequences, device, request.chunk_delay), media_type="text/plain")
        elif request.task_type in ["image", "audio", "video"]:
+             pipeline_func = model_data["pipeline"]
+             try:
+                 result = pipeline_func(request.input_text)
+                 if request.task_type == "image":
+                     image = result[0]
+                     img_byte_arr = BytesIO()
+                     image.save(img_byte_arr, format="PNG")
+                     img_byte_arr.seek(0)
+                     return StreamingResponse(img_byte_arr, media_type="image/png")
+                 elif request.task_type == "audio":
+                     audio = result[0]
+                     audio_byte_arr = BytesIO()
+                     audio.save(audio_byte_arr, format="wav")
+                     audio_byte_arr.seek(0)
+                     return StreamingResponse(audio_byte_arr, media_type="audio/wav")
+                 elif request.task_type == "video":
+                     video = result[0]
+                     video_byte_arr = BytesIO()
+                     video.save(video_byte_arr, format="mp4")
+                     video_byte_arr.seek(0)
+                     return StreamingResponse(video_byte_arr, media_type="video/mp4")
+             except Exception as e:
+                 raise HTTPException(status_code=500, detail=f"Error processing {request.task_type}: {e}")
        elif request.task_type == "classification":
            if request.image_file is None:
                raise HTTPException(status_code=400, detail="Image file is required for classification.")

            if request.audio_file is None:
                raise HTTPException(status_code=400, detail="Audio file is required for speech-to-text.")
            contents = await request.audio_file.read()
+             pipeline_func = model_data["pipeline"]
            try:
+                 transcription = pipeline_func(contents, sampling_rate=16000)[0]["text"]
                return JSONResponse({"transcription": transcription})
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Error during speech-to-text: {str(e)}")
        elif request.task_type == "text-to-speech":
            if not request.input_text:
                raise HTTPException(status_code=400, detail="Input text is required for text-to-speech.")
+             pipeline_func = model_data["pipeline"]
            try:
+                 audio = pipeline_func(request.input_text)[0]
                file_path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}.wav")
                audio.save(file_path)
                background_tasks.add_task(os.remove, file_path)
                return FileResponse(file_path, media_type="audio/wav")
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Error during text-to-speech: {str(e)}")
        elif request.task_type == "image-segmentation":
            if request.image_file is None:
                raise HTTPException(status_code=400, detail="Image file is required for image segmentation.")
            contents = await request.image_file.read()
            image = Image.open(BytesIO(contents)).convert("RGB")
+             pipeline_func = model_data["pipeline"]
+             try:
+                 result = pipeline_func(image)
+                 mask = result[0]['mask']
+                 mask_byte_arr = BytesIO()
+                 mask.save(mask_byte_arr, format="PNG")
+                 mask_byte_arr.seek(0)
+                 return StreamingResponse(mask_byte_arr, media_type="image/png")
+             except Exception as e:
+                 raise HTTPException(status_code=500, detail=f"Error during image segmentation: {e}")
        elif request.task_type == "feature-extraction":
            if request.raw_input is None:
                raise HTTPException(status_code=400, detail="raw_input is required for feature extraction.")

                    inputs = feature_extractor(images=image, return_tensors="pt")
                else:
                    raise ValueError("Unsupported raw_input type.")
+                 features = inputs.pixel_values
                return JSONResponse({"features": features.tolist()})
            except Exception as fe:
                raise HTTPException(status_code=400, detail=f"Error during feature extraction: {fe}")

            image_contents = await request.image_file.read()
            mask_contents = await request.mask_image.read()
            image = Image.open(BytesIO(image_contents)).convert("RGB")
+             mask = Image.open(BytesIO(mask_contents)).convert("L")
+             pipeline_func = model_data["pipeline"]
+             try:
+                 result = pipeline_func(image, mask)
+                 inpainted_image = result[0]
+                 img_byte_arr = BytesIO()
+                 inpainted_image.save(img_byte_arr, format="PNG")
+                 img_byte_arr.seek(0)
+                 return StreamingResponse(img_byte_arr, media_type="image/png")
+             except Exception as e:
+                 raise HTTPException(status_code=500, detail=f"Error during image inpainting: {e}")
        elif request.task_type == "image-super-resolution":
            if request.low_res_image is None:
                raise HTTPException(status_code=400, detail="low_res_image is required for image super-resolution.")
            contents = await request.low_res_image.read()
            image = Image.open(BytesIO(contents)).convert("RGB")
+             pipeline_func = model_data["pipeline"]
+             try:
+                 result = pipeline_func(image)
+                 upscaled_image = result[0]
+                 img_byte_arr = BytesIO()
+                 upscaled_image.save(img_byte_arr, format="PNG")
+                 img_byte_arr.seek(0)
+                 return StreamingResponse(img_byte_arr, media_type="image/png")
+             except Exception as e:
+                 raise HTTPException(status_code=500, detail=f"Error during image super-resolution: {e}")
        elif request.task_type == "object-detection":
            if request.image_file is None:
                raise HTTPException(status_code=400, detail="Image file is required for object detection.")
            contents = await request.image_file.read()
            image = Image.open(BytesIO(contents)).convert("RGB")
+             pipeline_func = model_data["pipeline"]
            image_processor = model_data["image_processor"]
            inputs = image_processor(images=image, return_tensors="pt")
            with torch.no_grad():
+                 try:
+                     outputs = pipeline_func(image)
+                     detections = outputs
+                     return JSONResponse({"detections": detections})
+                 except Exception as e:
+                     raise HTTPException(status_code=500, detail=f"Error during object detection: {e}")
        elif request.task_type == "image-captioning":
            if request.image_file is None:
                raise HTTPException(status_code=400, detail="Image file is required for image captioning.")
            contents = await request.image_file.read()
            image = Image.open(BytesIO(contents)).convert("RGB")
+             pipeline_func = model_data["pipeline"]
+             try:
+                 caption = pipeline_func(image)[0]['generated_text']
+                 return JSONResponse({"caption": caption})
+             except Exception as e:
+                 raise HTTPException(status_code=500, detail=f"Error during image captioning: {e}")
        elif request.task_type == "audio-transcription":
            if request.audio_file is None:
                raise HTTPException(status_code=400, detail="Audio file is required for audio transcription.")
+             contents = await request.audio_file.read()
+             pipeline_func = model_data["pipeline"]
            try:
+                 transcription = pipeline_func(contents, sampling_rate=16000)[0]["text"]
+                 return JSONResponse({"transcription": transcription})
            except Exception as e:
+                 raise HTTPException(status_code=500, detail=f"Error during audio transcription: {str(e)}")
        elif request.task_type == "summarization":
            if request.input_text is None:
                raise HTTPException(status_code=400, detail="Input text is required for summarization.")
            model = model_data["model"].to(device)
            tokenizer = model_data["tokenizer"]
+             inputs = tokenizer(request.input_text, return_tensors="pt", truncation=True, max_length=512).to(device)  # keep the inputs on the same device as the model
            with torch.no_grad():
+                 try:
+                     outputs = model.generate(**inputs)
+                     summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
+                     return JSONResponse({"summary": summary})
+                 except Exception as e:
+                     raise HTTPException(status_code=500, detail=f"Error during summarization: {e}")
        else:
            raise HTTPException(status_code=500, detail=f"Unsupported task type")
    except Exception as e:

async def health_check():
    return {"status": "healthy"}

+ class Token(BaseModel):
+     access_token: str
+     token_type: str

@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
+         raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"})
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires)
    return {"access_token": access_token, "token_type": "bearer"}

def authenticate_user(username: str, password: str):
+     cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
+     user = cursor.fetchone()
+     if user and pwd_context.verify(password, user[2]):
+         return {"username": username}
    return None

def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None):

    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

@app.get("/users/me")
async def read_users_me(current_user: str = Depends(get_current_user)):
    return {"username": current_user}

async def get_current_user(token: str = Depends(oauth2_scheme)):
+     credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"})
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
+     cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
+     user = cursor.fetchone()
    if user is None:
        raise credentials_exception
    return username

async def create_user(user: User):
    try:
        hashed_password = pwd_context.hash(user.password)
+         cursor.execute("INSERT INTO users (username, email, hashed_password) VALUES (?, ?, ?)", (user.username, user.email, hashed_password))
+         conn.commit()
+         return user
+     except sqlite3.IntegrityError:
+         raise HTTPException(status_code=400, detail="Username or email already exists")
    except Exception as e:
        logger.error(f"Error creating user: {e}")
        raise HTTPException(status_code=500, detail=f"Error creating user: {e}")

async def update_user_data(username: str, user: User):
    try:
        hashed_password = pwd_context.hash(user.password)
+         cursor.execute("UPDATE users SET email = ?, hashed_password = ? WHERE username = ?", (user.email, hashed_password, username))
+         conn.commit()
+         return user
    except Exception as e:
        logger.error(f"Error updating user: {e}")
        raise HTTPException(status_code=500, detail="Error updating user.")


@app.delete("/users/{username}", dependencies=[Depends(get_current_user)])
async def delete_user_account(username: str):
    try:
+         cursor.execute("DELETE FROM users WHERE username = ?", (username,))
+         conn.commit()
+         return JSONResponse({"message": "User deleted successfully."}, status_code=200)
    except Exception as e:
        logger.error(f"Error deleting user: {e}")
        raise HTTPException(status_code=500, detail="Error deleting user.")

@app.get("/users", dependencies=[Depends(get_current_user)])
async def get_all_users_route():
+     cursor.execute("SELECT username, email FROM users")
+     users = cursor.fetchall()
+     return [{"username": user[0], "email": user[1]} for user in users]

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
+     return JSONResponse(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content=json.dumps({"detail": exc.errors(), "body": exc.body}))

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
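
A quick way to exercise the updated /generate endpoint is a small client script. The sketch below is not part of the commit; it assumes the app is running locally on port 7860, that the X-API-Key header matches the server's API_KEY environment variable, and that "my-model" is a placeholder for a text model the S3ModelLoader can actually load. The field names (model_id, task_type, input_text, temperature, max_new_tokens) come from the GenerateRequest model above; the text task streams a text/plain response.

# Hypothetical client-side sketch for the /generate endpoint defined in app.py.
# Assumes a local deployment on port 7860 and a placeholder model name.
import os
import requests

API_URL = "http://localhost:7860/generate"   # assumed local base URL
API_KEY = os.environ["API_KEY"]              # must match the server's API_KEY

payload = {
    "model_id": "my-model",                  # placeholder model identifier
    "task_type": "text",
    "input_text": "Write a haiku about FastAPI.",
    "temperature": 0.8,
    "max_new_tokens": 64,
}

# The "text" task returns a streaming plain-text body, so read it incrementally.
with requests.post(API_URL, json=payload, headers={"X-API-Key": API_KEY}, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)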