# Spaces: Sleeping  (page status text captured with the file; not part of the module)
import sqlite3 | |
from sqlite3 import Error | |
from pydantic import BaseModel | |
class User(BaseModel):
    """User record exchanged with the ``users`` table.

    Holds the three text columns of a row; the table's integer ``id``
    column is not part of the model.
    """

    # Value of the ``username`` column (declared NOT NULL UNIQUE in the table).
    username: str
    # Value of the ``email`` column (declared NOT NULL UNIQUE in the table).
    email: str
    # Value of the ``hashed_password`` column; presumably produced by a
    # hashing step outside this module -- nothing here hashes it.
    hashed_password: str
def create_connection(db_file):
    """Open a SQLite connection to ``db_file``.

    Returns the ``sqlite3.Connection`` on success. If ``sqlite3`` raises an
    ``Error`` while connecting, the error is printed and ``None`` is
    returned, so callers have to check the result before using it.
    """
    try:
        return sqlite3.connect(db_file)
    except Error as exc:
        # Best-effort contract: report the failure and hand back None
        # instead of propagating the exception.
        print(exc)
        return None
def create_table(conn):
    """Create the ``users`` table on ``conn`` if it does not exist yet.

    The table has an integer primary key ``id`` plus three required text
    columns: ``username`` and ``email`` (both UNIQUE) and
    ``hashed_password``.

    Args:
        conn: An open ``sqlite3`` connection.

    Any ``sqlite3.Error`` raised while executing the statement is printed
    and swallowed; nothing is returned either way.
    """
    # The statement is kept free of stray characters so SQLite can parse it;
    # IF NOT EXISTS makes repeated calls harmless.
    sql = """
        CREATE TABLE IF NOT EXISTS users (
            id integer PRIMARY KEY,
            username text NOT NULL UNIQUE,
            email text NOT NULL UNIQUE,
            hashed_password text NOT NULL
        );
    """
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
    except Error as exc:
        print(exc)
def create_db_and_table():
    """Make sure ``users.db`` exists and contains the ``users`` table.

    Opens the database through ``create_connection``; when that yields no
    connection an error message is printed and nothing else happens.
    Otherwise the table is created (if missing) and the connection closed.
    """
    db_connection = create_connection("users.db")
    if db_connection is None:
        print("Error! cannot create the database connection.")
        return
    create_table(db_connection)
    db_connection.close()
def insert_user(user_data: dict):
    """Insert a new row into ``users`` and return it as a ``User``.

    Args:
        user_data: Mapping with the keys ``username``, ``email`` and
            ``hashed_password``.

    Returns:
        The stored user, re-read through ``get_user``, or ``None`` when the
        insert reported no row id.

    Raises:
        KeyError: If one of the three keys is missing from ``user_data``.
        sqlite3.Error: Propagated from the insert or commit (for example an
            ``IntegrityError`` when ``username`` or ``email`` is already
            taken, since both columns are UNIQUE).
        AttributeError: If ``create_connection`` returned ``None``.
    """
    # Read the fields before opening the connection so a malformed dict
    # fails without leaving a connection behind.
    params = (
        user_data['username'],
        user_data['email'],
        user_data['hashed_password'],
    )
    sql = """
        INSERT INTO users(username, email, hashed_password)
        VALUES(?, ?, ?)
    """
    conn = create_connection('users.db')
    cur = conn.cursor()
    try:
        cur.execute(sql, params)
        conn.commit()
        last_row_id = cur.lastrowid
    finally:
        # Close even when the insert fails; errors still reach the caller.
        conn.close()
    if not last_row_id:
        return None
    # Fetch the inserted user for consistency with the other accessors.
    return get_user(user_data['username'])
def get_user(username: str):
    """Look up a single user by ``username``.

    Args:
        username: Value matched against the ``username`` column.

    Returns:
        A ``User`` built from the matching row, or ``None`` when no row
        matches.

    Raises:
        sqlite3.Error: Propagated from the query.
        AttributeError: If ``create_connection`` returned ``None``.
    """
    conn = create_connection('users.db')
    cur = conn.cursor()
    try:
        # Name the columns instead of SELECT * so the mapping below does not
        # depend on the physical column order of the table.
        cur.execute(
            "SELECT username, email, hashed_password FROM users WHERE username=?",
            (username,),
        )
        row = cur.fetchone()
    finally:
        # Close even when the query fails.
        conn.close()
    if row is None:
        return None
    name, email, hashed_password = row
    return User(username=name, email=email, hashed_password=hashed_password)
def delete_user(username: str):
    """Delete the row whose ``username`` column equals ``username``.

    Args:
        username: Name of the user to remove.

    Returns:
        ``True`` if at least one row was deleted, ``False`` otherwise.

    Raises:
        sqlite3.Error: Propagated from the delete or commit.
        AttributeError: If ``create_connection`` returned ``None``.
    """
    conn = create_connection('users.db')
    cur = conn.cursor()
    try:
        cur.execute("DELETE FROM users WHERE username=?", (username,))
        conn.commit()
        rows_affected = cur.rowcount
    finally:
        # Close even when the delete fails; errors still reach the caller.
        conn.close()
    return rows_affected > 0
def update_user(username: str, updated_data: dict):
    """Overwrite the email and password hash of an existing user.

    Args:
        username: Identifies the row to update; the username itself is not
            changed.
        updated_data: Mapping with the keys ``email`` and
            ``hashed_password``; both are required.

    Returns:
        The updated user, re-read through ``get_user``, or ``None`` when no
        row matched ``username``.

    Raises:
        KeyError: If ``email`` or ``hashed_password`` is missing.
        sqlite3.Error: Propagated from the update or commit (for example an
            ``IntegrityError`` when the new email collides with another row,
            since ``email`` is UNIQUE).
        AttributeError: If ``create_connection`` returned ``None``.
    """
    # Read the fields before opening the connection so a malformed dict
    # fails without leaving a connection behind.
    params = (updated_data['email'], updated_data['hashed_password'], username)
    sql = """
        UPDATE users
        SET email = ?,
            hashed_password = ?
        WHERE username = ?
    """
    conn = create_connection('users.db')
    cur = conn.cursor()
    try:
        cur.execute(sql, params)
        conn.commit()
        rows_affected = cur.rowcount
    finally:
        # Close even when the update fails; errors still reach the caller.
        conn.close()
    if rows_affected <= 0:
        return None
    # Fetch the updated user for consistency with the other accessors.
    return get_user(username)
def get_all_users():
    """Return every row of the ``users`` table as a list of ``User`` objects.

    Returns:
        A list of ``User`` instances in ascending ``id`` order; empty when
        the table has no rows.

    Raises:
        sqlite3.Error: Propagated from the query.
        AttributeError: If ``create_connection`` returned ``None``.
    """
    conn = create_connection('users.db')
    cur = conn.cursor()
    try:
        # Name the columns instead of SELECT * so the mapping below does not
        # depend on the physical column order; ORDER BY id makes the row
        # order explicit rather than an accident of the table scan.
        cur.execute(
            "SELECT username, email, hashed_password FROM users ORDER BY id"
        )
        rows = cur.fetchall()
    finally:
        # Close even when the query fails.
        conn.close()
    return [
        User(username=name, email=email, hashed_password=hashed_password)
        for name, email, hashed_password in rows
    ]