# aws_test / database.py
# Hjgugugjhuhjggg's picture
# Create database.py
# da70b73 verified
# raw
# history blame
# 3.08 kB
import sqlite3
from sqlite3 import Error
from pydantic import BaseModel
class User(BaseModel):
    """Pydantic model for one row of the ``users`` table.

    Instances are built by ``get_user`` and ``get_all_users`` from the
    ``username``, ``email`` and ``hashed_password`` columns; the row ``id``
    is not included.
    """
    # Unique login name (the table declares it NOT NULL UNIQUE).
    username: str
    # Contact address (the table declares it NOT NULL UNIQUE).
    email: str
    # Password hash exactly as stored; this module never hashes anything
    # itself. NOTE(review): the hash is returned to callers via this model —
    # confirm it is never serialized back to clients.
    hashed_password: str
def create_connection(db_file):
    """Open a SQLite connection to *db_file*.

    Returns the new ``sqlite3.Connection`` on success. If connecting raises
    an ``sqlite3.Error``, the error is printed and ``None`` is returned, so
    callers must check the result before using it.
    """
    try:
        connection = sqlite3.connect(db_file)
    except Error as exc:
        print(exc)
        return None
    return connection
def create_table(conn):
    """Create the ``users`` table on *conn* unless it already exists.

    Columns: ``id`` (integer primary key), ``username`` and ``email``
    (both required and unique) and a required ``hashed_password``.
    Any ``sqlite3.Error`` raised while running the statement is printed
    instead of being propagated.
    """
    ddl = """ CREATE TABLE IF NOT EXISTS users (
id integer PRIMARY KEY,
username text NOT NULL UNIQUE,
email text NOT NULL UNIQUE,
hashed_password text NOT NULL
); """
    try:
        conn.execute(ddl)
    except Error as err:
        print(err)
def create_db_and_table():
    """Make sure ``users.db`` exists and contains the ``users`` table.

    Opens the database with ``create_connection``, runs ``create_table`` on
    it and closes the connection. If no connection could be opened, an
    error message is printed instead.
    """
    conn = create_connection("users.db")
    if conn is None:
        print("Error! cannot create the database connection.")
        return
    create_table(conn)
    conn.close()
def insert_user(user_data: dict):
    """Insert a new user and return it as stored in the database.

    Args:
        user_data: mapping that must contain ``'username'``, ``'email'`` and
            ``'hashed_password'``; any other keys are ignored.

    Returns:
        The inserted ``User`` re-read via ``get_user``, or ``None`` if SQLite
        reported no row id for the insert.

    Raises:
        KeyError: if a required key is missing from *user_data*.
        sqlite3.IntegrityError: if the username or email is already taken
            (both columns are declared UNIQUE).
        sqlite3.Error: if the database connection cannot be opened or the
            statement otherwise fails.
    """
    # Extract the values first so a missing key fails before any
    # connection is opened.
    params = (user_data['username'], user_data['email'], user_data['hashed_password'])
    conn = create_connection('users.db')
    if conn is None:
        # create_connection has already printed the underlying error; fail
        # with a clear database error instead of crashing on None.cursor().
        raise Error("Error! cannot create the database connection.")
    try:
        cur = conn.execute(
            "INSERT INTO users(username, email, hashed_password) VALUES(?, ?, ?)",
            params,
        )
        conn.commit()
        last_row_id = cur.lastrowid
    finally:
        # Close on every path so a failed INSERT (e.g. a UNIQUE violation)
        # does not leave the connection and its uncommitted transaction open.
        conn.close()
    if last_row_id:
        # Re-read the row so callers get exactly what was stored.
        return get_user(user_data['username'])
    return None
def get_user(username: str):
    """Look up a single user by exact username.

    Args:
        username: username to match.

    Returns:
        A ``User`` built from the matching row, or ``None`` when no row has
        that username.

    Raises:
        sqlite3.Error: if the database connection cannot be opened or the
            query fails (for example when the ``users`` table is missing).
    """
    conn = create_connection('users.db')
    if conn is None:
        # create_connection has already printed the underlying error.
        raise Error("Error! cannot create the database connection.")
    try:
        # Name the columns explicitly rather than SELECT * so the mapping
        # below does not depend on the table's physical column order.
        row = conn.execute(
            "SELECT username, email, hashed_password FROM users WHERE username=?",
            (username,),
        ).fetchone()
    finally:
        # Release the connection even if the query raises.
        conn.close()
    if row is None:
        return None
    return User(username=row[0], email=row[1], hashed_password=row[2])
def delete_user(username: str):
    """Delete the user with the given username.

    Args:
        username: username of the row to remove.

    Returns:
        ``True`` if a row was deleted, ``False`` if no user matched.

    Raises:
        sqlite3.Error: if the database connection cannot be opened or the
            DELETE/COMMIT fails.
    """
    conn = create_connection('users.db')
    if conn is None:
        # create_connection has already printed the underlying error.
        raise Error("Error! cannot create the database connection.")
    try:
        cur = conn.execute("DELETE FROM users WHERE username=?", (username,))
        conn.commit()
        rows_affected = cur.rowcount
    finally:
        # Release the connection even when the DELETE or COMMIT raises.
        conn.close()
    return rows_affected > 0
def update_user(username: str, updated_data: dict):
    """Replace the email and password hash of an existing user.

    Args:
        username: user to update; the username itself is not changed.
        updated_data: mapping that must contain ``'email'`` and
            ``'hashed_password'``; any other keys are ignored.

    Returns:
        The updated ``User`` re-read via ``get_user``, or ``None`` when no
        user has that username.

    Raises:
        KeyError: if a required key is missing from *updated_data*.
        sqlite3.IntegrityError: if the new email already belongs to another
            user (the column is declared UNIQUE).
        sqlite3.Error: if the database connection cannot be opened or the
            statement otherwise fails.
    """
    # Extract the values first so a missing key fails before any
    # connection is opened.
    params = (updated_data['email'], updated_data['hashed_password'], username)
    conn = create_connection('users.db')
    if conn is None:
        # create_connection has already printed the underlying error.
        raise Error("Error! cannot create the database connection.")
    try:
        cur = conn.execute(
            "UPDATE users SET email = ?, hashed_password = ? WHERE username = ?",
            params,
        )
        conn.commit()
        rows_affected = cur.rowcount
    finally:
        # Close on every path so a failed UPDATE does not leave the
        # connection and its uncommitted transaction open.
        conn.close()
    if rows_affected > 0:
        # Re-read the row so callers get exactly what was stored.
        return get_user(username)
    return None
def get_all_users():
    """Return every user in the ``users`` table.

    Returns:
        A list of ``User`` objects ordered by row ``id`` (insertion order for
        auto-assigned ids); empty when the table has no rows.

    Raises:
        sqlite3.Error: if the database connection cannot be opened or the
            query fails (for example when the ``users`` table is missing).
    """
    conn = create_connection('users.db')
    if conn is None:
        # create_connection has already printed the underlying error.
        raise Error("Error! cannot create the database connection.")
    try:
        # Explicit columns decouple the mapping below from the physical
        # column order; ORDER BY makes the result order well-defined.
        rows = conn.execute(
            "SELECT username, email, hashed_password FROM users ORDER BY id"
        ).fetchall()
    finally:
        # Release the connection even if the query raises.
        conn.close()
    return [
        User(username=name, email=mail, hashed_password=pw_hash)
        for name, mail, pw_hash in rows
    ]