File size: 3,080 Bytes
da70b73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import sqlite3
from sqlite3 import Error
from pydantic import BaseModel

class User(BaseModel):
    # Pydantic model for a user record. The query helpers in this module
    # populate it from the username, email and hashed_password columns of the
    # ``users`` table; the integer ``id`` primary key is not part of this model.
    username: str
    email: str
    # Stored and returned verbatim by this module -- no hashing happens here,
    # so callers are presumably expected to pass an already-hashed value.
    hashed_password: str


def create_connection(db_file):
    """Open and return a connection to the SQLite database at ``db_file``.

    If ``sqlite3.connect`` raises a ``sqlite3.Error`` (for example because the
    path cannot be opened), the error is printed and ``None`` is returned
    instead of the exception propagating.
    """
    try:
        return sqlite3.connect(db_file)
    except Error as exc:
        print(exc)
    return None


def create_table(conn):
    """Ensure the ``users`` table exists on the given connection.

    The DDL uses ``CREATE TABLE IF NOT EXISTS``, so calling this repeatedly is
    harmless. Any ``sqlite3.Error`` raised while executing it is printed rather
    than propagated.

    Args:
        conn: An open ``sqlite3.Connection``.
    """
    ddl = """ CREATE TABLE IF NOT EXISTS users (
                                        id integer PRIMARY KEY,
                                        username text NOT NULL UNIQUE,
                                        email text NOT NULL UNIQUE,
                                        hashed_password text NOT NULL
                                    ); """
    try:
        conn.execute(ddl)
    except Error as exc:
        print(exc)


def create_db_and_table(database: str = "users.db"):
    """Create the SQLite database (if needed) and the ``users`` table in it.

    Args:
        database: Path of the SQLite database file. Defaults to ``"users.db"``,
            the path the other helpers in this module use.

    If the connection cannot be opened, an error message is printed and the
    function returns without raising.
    """
    conn = create_connection(database)
    if conn is None:
        print("Error! cannot create the database connection.")
        return
    try:
        create_table(conn)
    finally:
        # Close the connection even if create_table raises something other
        # than sqlite3.Error (which it handles itself).
        conn.close()


def insert_user(user_data: dict):
    """Insert a new user row into ``users.db`` and return it as a ``User``.

    Args:
        user_data: Mapping that must contain ``'username'``, ``'email'`` and
            ``'hashed_password'``.

    Returns:
        The inserted user, re-read via ``get_user``, or ``None`` if the cursor
        reports a falsy ``lastrowid``.

    Raises:
        KeyError: if ``user_data`` lacks one of the required keys.
        sqlite3.IntegrityError: if the username or email is already taken
            (assuming the table was created by ``create_table``, which declares
            both columns UNIQUE).
    """
    # Extract the parameters before connecting so that a missing key cannot
    # leave a connection open.
    params = (user_data['username'], user_data['email'], user_data['hashed_password'])
    conn = create_connection('users.db')
    cur = conn.cursor()
    sql = """ INSERT INTO users(username,email,hashed_password)
              VALUES(?,?,?) """
    try:
        cur.execute(sql, params)
        conn.commit()
        last_row_id = cur.lastrowid
    finally:
        # Close even when the INSERT fails (e.g. a UNIQUE violation); closing
        # without a commit discards the uncommitted transaction.
        conn.close()
    if last_row_id:
        # Fetch the inserted user so callers get the same shape as a lookup.
        return get_user(user_data['username'])
    return None


def get_user(username: str):
    """Look up a single user in ``users.db`` by username.

    Args:
        username: The username to search for.

    Returns:
        A ``User`` built from the matching row, or ``None`` if no row matches.
    """
    conn = create_connection('users.db')
    cur = conn.cursor()
    try:
        # Name the columns explicitly so the unpacking below does not depend
        # on the table's physical column order (as ``SELECT *`` + indexes did).
        cur.execute(
            "SELECT username, email, hashed_password FROM users WHERE username=?",
            (username,),
        )
        row = cur.fetchone()
    finally:
        # Close even if the query fails (e.g. the table does not exist yet).
        conn.close()
    if row is None:
        return None
    name, email, hashed_password = row
    return User(username=name, email=email, hashed_password=hashed_password)


def delete_user(username: str):
    """Delete the user with the given username from ``users.db``.

    Args:
        username: Username of the row to delete.

    Returns:
        ``True`` if at least one row was deleted, ``False`` if no user matched.
    """
    conn = create_connection('users.db')
    cur = conn.cursor()
    try:
        cur.execute("DELETE FROM users WHERE username=?", (username,))
        conn.commit()
        rows_affected = cur.rowcount
    finally:
        # Always release the connection, even if the DELETE raises.
        conn.close()
    return rows_affected > 0


def update_user(username: str, updated_data: dict):
    """Overwrite the email and password hash of an existing user in ``users.db``.

    Args:
        username: Username of the row to update.
        updated_data: Mapping that must contain ``'email'`` and
            ``'hashed_password'``; both columns are overwritten.

    Returns:
        The updated ``User`` re-read via ``get_user``, or ``None`` if no user
        with ``username`` exists.

    Raises:
        KeyError: if ``updated_data`` lacks one of the required keys.
        sqlite3.IntegrityError: if the new email is already used by another
            row (assuming the table was created by ``create_table``, which
            declares ``email`` UNIQUE).
    """
    # Build the parameters before connecting so that a missing key cannot
    # leave a connection open.
    params = (updated_data['email'], updated_data['hashed_password'], username)
    conn = create_connection('users.db')
    cur = conn.cursor()
    sql = """ UPDATE users
              SET email = ?,
                  hashed_password = ?
              WHERE username = ?"""
    try:
        cur.execute(sql, params)
        conn.commit()
        rows_affected = cur.rowcount
    finally:
        # Close even when the UPDATE fails; uncommitted changes are discarded.
        conn.close()
    if rows_affected > 0:
        # Fetch the updated user so callers get the same shape as a lookup.
        return get_user(username)
    return None

def get_all_users():
    """Return every row of the ``users`` table in ``users.db`` as ``User`` objects.

    Returns:
        A list of ``User`` instances; empty when the table has no rows. No
        ordering is requested in the query.
    """
    conn = create_connection('users.db')
    cur = conn.cursor()
    try:
        # Select named columns rather than ``*`` so the unpacking below does
        # not rely on the table's physical column order.
        cur.execute("SELECT username, email, hashed_password FROM users")
        rows = cur.fetchall()
    finally:
        # Close even if the query fails (e.g. the table does not exist yet).
        conn.close()
    return [
        User(username=username, email=email, hashed_password=hashed_password)
        for username, email, hashed_password in rows
    ]