Hjgugugjhuhjggg commited on
Commit
2b9f02a
·
verified ·
1 Parent(s): a42c4c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -60
app.py CHANGED
@@ -42,6 +42,7 @@ from fastapi.templating import Jinja2Templates
42
  from fastapi.middleware.gzip import GZipMiddleware
43
  from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
44
  from starlette.middleware.cors import CORSMiddleware
 
45
 
46
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
47
  logger = logging.getLogger(__name__)
@@ -54,17 +55,6 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
54
 
55
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
56
 
57
- conn = sqlite3.connect('users.db', check_same_thread=False)
58
- cursor = conn.cursor()
59
- cursor.execute('''
60
- CREATE TABLE IF NOT EXISTS users (
61
- username TEXT PRIMARY KEY,
62
- email TEXT UNIQUE,
63
- hashed_password TEXT
64
- )
65
- ''')
66
- conn.commit()
67
-
68
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
69
  API_KEY = os.getenv("API_KEY")
70
  api_key_header = APIKeyHeader(name="X-API-Key")
@@ -77,6 +67,7 @@ HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
77
  TEMP_DIR = "/tmp"
78
  STATIC_DIR = "static"
79
  TEMPLATES = Jinja2Templates(directory="templates")
 
80
 
81
  app = FastAPI()
82
  app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
@@ -568,73 +559,56 @@ class Token(BaseModel):
568
  access_token: str
569
  token_type: str
570
 
 
 
 
 
 
 
 
 
 
 
 
571
  @app.post("/token", response_model=Token)
572
- async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
573
- user = authenticate_user(form_data.username, form_data.password)
574
  if not user:
575
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"})
576
  access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
577
  access_token = create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires)
578
  return {"access_token": access_token, "token_type": "bearer"}
579
 
580
- def authenticate_user(username: str, password: str):
581
- cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
582
- user = cursor.fetchone()
583
- if user and pwd_context.verify(password, user[2]):
584
- return {"username": username}
585
- return None
586
-
587
- def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None):
588
- to_encode = data.copy()
589
- if expires_delta:
590
- expire = datetime.utcnow() + expires_delta
591
- else:
592
- expire = datetime.utcnow() + timedelta(minutes=15)
593
- to_encode.update({"exp": expire})
594
- encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
595
- return encoded_jwt
596
-
597
 
598
- @app.get("/users/me")
599
- async def read_users_me(current_user: str = Depends(get_current_user)):
600
- return {"username": current_user}
601
-
602
- async def get_current_user(token: str = Depends(oauth2_scheme)):
603
  credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"})
604
  try:
605
  payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
606
  username: str = payload.get("sub")
607
  if username is None:
608
  raise credentials_exception
 
 
 
 
609
  except JWTError:
610
  raise credentials_exception
611
- cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
612
- user = cursor.fetchone()
613
- if user is None:
614
- raise credentials_exception
615
- return username
616
-
617
 
618
  @app.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
619
- async def create_user(user: User):
 
620
  try:
621
- hashed_password = pwd_context.hash(user.password)
622
- cursor.execute("INSERT INTO users (username, email, hashed_password) VALUES (?, ?, ?)", (user.username, user.email, hashed_password))
623
- conn.commit()
624
  return user
625
- except sqlite3.IntegrityError:
626
  raise HTTPException(status_code=400, detail="Username or email already exists")
627
- except Exception as e:
628
- logger.error(f"Error creating user: {e}")
629
- raise HTTPException(status_code=500, detail=f"Error creating user: {e}")
630
 
631
 
632
  @app.put("/users/{username}", response_model=User, dependencies=[Depends(get_current_user)])
633
- async def update_user_data(username: str, user: User):
 
634
  try:
635
- hashed_password = pwd_context.hash(user.password)
636
- cursor.execute("UPDATE users SET email = ?, hashed_password = ? WHERE username = ?", (user.email, hashed_password, username))
637
- conn.commit()
638
  return user
639
  except Exception as e:
640
  logger.error(f"Error updating user: {e}")
@@ -642,10 +616,9 @@ async def update_user_data(username: str, user: User):
642
 
643
 
644
  @app.delete("/users/{username}", dependencies=[Depends(get_current_user)])
645
- async def delete_user_account(username: str):
646
  try:
647
- cursor.execute("DELETE FROM users WHERE username = ?", (username,))
648
- conn.commit()
649
  return JSONResponse({"message": "User deleted successfully."}, status_code=200)
650
  except Exception as e:
651
  logger.error(f"Error deleting user: {e}")
@@ -653,10 +626,17 @@ async def delete_user_account(username: str):
653
 
654
 
655
  @app.get("/users", dependencies=[Depends(get_current_user)])
656
- async def get_all_users_route():
657
- cursor.execute("SELECT username, email FROM users")
658
- users = cursor.fetchall()
659
- return [{"username": user[0], "email": user[1]} for user in users]
 
 
 
 
 
 
 
660
 
661
 
662
  @app.exception_handler(RequestValidationError)
@@ -667,5 +647,16 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
667
  )
668
 
669
 
 
 
 
 
 
 
 
 
 
 
 
670
  if __name__ == "__main__":
671
  uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
 
42
  from fastapi.middleware.gzip import GZipMiddleware
43
  from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
44
  from starlette.middleware.cors import CORSMiddleware
45
+ import asyncpg
46
 
47
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
48
  logger = logging.getLogger(__name__)
 
55
 
56
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
57
 
 
 
 
 
 
 
 
 
 
 
 
58
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
59
  API_KEY = os.getenv("API_KEY")
60
  api_key_header = APIKeyHeader(name="X-API-Key")
 
67
  TEMP_DIR = "/tmp"
68
  STATIC_DIR = "static"
69
  TEMPLATES = Jinja2Templates(directory="templates")
70
+ DATABASE_URL = os.getenv("DATABASE_URL")
71
 
72
  app = FastAPI()
73
  app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
 
559
  access_token: str
560
  token_type: str
561
 
562
+ async def get_db():
563
+ async with asyncpg.create_pool(DATABASE_URL) as pool:
564
+ async with pool.acquire() as conn:
565
+ yield conn
566
+
567
+ async def authenticate_user(username, password, conn):
568
+ row = await conn.fetchrow("SELECT * FROM users WHERE username = $1", username)
569
+ if row is not None and pwd_context.verify(password, row["hashed_password"]):
570
+ return {"username": username}
571
+ return None
572
+
573
  @app.post("/token", response_model=Token)
574
+ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), conn = Depends(get_db)):
575
+ user = await authenticate_user(form_data.username, form_data.password, conn)
576
  if not user:
577
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"})
578
  access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
579
  access_token = create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires)
580
  return {"access_token": access_token, "token_type": "bearer"}
581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
582
 
583
+ async def get_current_user(token: str = Depends(oauth2_scheme), conn = Depends(get_db)):
 
 
 
 
584
  credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"})
585
  try:
586
  payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
587
  username: str = payload.get("sub")
588
  if username is None:
589
  raise credentials_exception
590
+ user = await conn.fetchrow("SELECT * FROM users WHERE username = $1", username)
591
+ if user is None:
592
+ raise credentials_exception
593
+ return username
594
  except JWTError:
595
  raise credentials_exception
 
 
 
 
 
 
596
 
597
  @app.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
598
+ async def create_user(user: User, conn = Depends(get_db)):
599
+ hashed_password = pwd_context.hash(user.password)
600
  try:
601
+ await conn.execute("INSERT INTO users (username, email, hashed_password) VALUES ($1, $2, $3)", user.username, user.email, hashed_password)
 
 
602
  return user
603
+ except asyncpg.exceptions.UniqueViolationError:
604
  raise HTTPException(status_code=400, detail="Username or email already exists")
 
 
 
605
 
606
 
607
  @app.put("/users/{username}", response_model=User, dependencies=[Depends(get_current_user)])
608
+ async def update_user_data(username: str, user: User, conn = Depends(get_db)):
609
+ hashed_password = pwd_context.hash(user.password)
610
  try:
611
+ await conn.execute("UPDATE users SET email = $1, hashed_password = $2 WHERE username = $3", user.email, hashed_password, username)
 
 
612
  return user
613
  except Exception as e:
614
  logger.error(f"Error updating user: {e}")
 
616
 
617
 
618
  @app.delete("/users/{username}", dependencies=[Depends(get_current_user)])
619
+ async def delete_user_account(username: str, conn = Depends(get_db)):
620
  try:
621
+ await conn.execute("DELETE FROM users WHERE username = $1", username)
 
622
  return JSONResponse({"message": "User deleted successfully."}, status_code=200)
623
  except Exception as e:
624
  logger.error(f"Error deleting user: {e}")
 
626
 
627
 
628
  @app.get("/users", dependencies=[Depends(get_current_user)])
629
+ async def get_all_users_route(conn = Depends(get_db)):
630
+ rows = await conn.fetch("SELECT username, email FROM users")
631
+ return [{"username": row["username"], "email": row["email"]} for row in rows]
632
+
633
+
634
+ @app.get("/users/me", dependencies=[Depends(get_current_user)]) # Requires authentication
635
+ async def read_users_me(current_user: str = Depends(get_current_user), conn=Depends(get_db)):
636
+ user = await conn.fetchrow("SELECT username, email FROM users WHERE username = $1", current_user)
637
+ if user:
638
+ return {"username": user["username"], "email": user["email"]}
639
+ raise HTTPException(status_code=404, detail="User not found")
640
 
641
 
642
  @app.exception_handler(RequestValidationError)
 
647
  )
648
 
649
 
650
+ def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None):
651
+ to_encode = data.copy()
652
+ if expires_delta:
653
+ expire = datetime.utcnow() + expires_delta
654
+ else:
655
+ expire = datetime.utcnow() + timedelta(minutes=15)
656
+ to_encode.update({"exp": expire})
657
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
658
+ return encoded_jwt
659
+
660
+
661
  if __name__ == "__main__":
662
  uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)