|
import logging |
|
import os |
|
import time |
|
|
|
import docker |
|
import pytest |
|
from docker import DockerClient |
|
from pytest_docker.plugin import get_docker_ip |
|
from fastapi.testclient import TestClient |
|
from sqlalchemy import text, create_engine |
|
|
|
|
|
log = logging.getLogger(__name__) |
|
|
|
|
|
def get_fast_api_client():
    """Build a ``TestClient`` for the application defined in ``main``.

    ``main`` is imported lazily, at call time rather than at module import
    time.

    The client is entered and exited as a context manager *before* it is
    handed back, so the caller receives a client whose ``with`` block has
    already finished.
    NOTE(review): presumably this is done to trigger the app's
    startup/shutdown handlers once — confirm that is intended.

    Returns:
        The object yielded by entering ``TestClient(app)``.
    """
    from main import app

    client_cm = TestClient(app)
    with client_cm as client:
        # Nothing to do inside the block; entering and leaving is the point.
        pass
    return client
|
|
|
|
|
class AbstractIntegrationTest:
    """Base class for integration tests that target one API path prefix.

    Subclasses set ``BASE_PATH`` and may override the xunit-style hooks
    (``setup_class``, ``setup_method``, ``teardown_class``,
    ``teardown_method``), all of which are no-ops here.
    """

    # Path prefix every URL built by ``create_url`` starts with; must be
    # overridden by subclasses before ``create_url`` is called.
    BASE_PATH = None

    def create_url(self, path="", query_params=None) -> str:
        """Join ``BASE_PATH``, *path* and an optional query string.

        Both ``BASE_PATH`` and *path* are split on ``/``; empty segments are
        dropped and each remaining segment is stripped of surrounding
        whitespace. The segments are re-joined with single slashes, so the
        result has neither a leading nor a trailing slash
        (``"/api/v1/models"`` + ``"/add"`` -> ``"api/v1/models/add"``).

        Args:
            path: Path relative to ``BASE_PATH``.
            query_params: Optional mapping of query parameters. Keys and
                values are percent-encoded, so values containing ``&``,
                ``=``, ``#`` or spaces cannot corrupt the query string.

        Returns:
            The relative URL, followed by ``?key=value&...`` when
            *query_params* is non-empty.

        Raises:
            ValueError: If ``BASE_PATH`` has not been set.
        """
        # Function-scope import keeps this block self-contained.
        from urllib.parse import urlencode

        if self.BASE_PATH is None:
            raise ValueError("BASE_PATH is not set")

        segments = [
            part.strip()
            for part in self.BASE_PATH.split("/") + path.split("/")
            if part.strip() != ""
        ]
        query_string = f"?{urlencode(query_params)}" if query_params else ""
        return "/".join(segments) + query_string

    @classmethod
    def setup_class(cls):
        """Hook run once before the tests of the class; no-op by default."""
        pass

    def setup_method(self):
        """Hook run before each test method; no-op by default."""
        pass

    @classmethod
    def teardown_class(cls):
        """Hook run once after the tests of the class; no-op by default."""
        pass

    def teardown_method(self):
        """Hook run after each test method; no-op by default."""
        pass
|
|
|
|
|
class AbstractPostgresTest(AbstractIntegrationTest):
    """Integration-test base backed by a throwaway Postgres Docker container.

    ``setup_class`` starts a ``postgres:16.2`` container, exports its
    connection string through the ``DATABASE_URL`` environment variable,
    waits until the database accepts connections and then creates
    ``cls.fast_api_client``. ``teardown_class`` force-removes the container.
    After every test method the tables listed in ``teardown_method`` are
    truncated.
    """

    DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
    # Host port the container's 5432 is published on; used both when starting
    # the container and when building the connection URL.
    POSTGRES_HOST_PORT = 8081
    docker_client: DockerClient

    @classmethod
    def _create_db_url(cls, env_vars_postgres: dict) -> str:
        """Build the ``postgresql://`` URL for the test container.

        Args:
            env_vars_postgres: Mapping holding ``POSTGRES_USER``,
                ``POSTGRES_PASSWORD`` and ``POSTGRES_DB``.

        Returns:
            A URL pointing at ``get_docker_ip()`` on ``POSTGRES_HOST_PORT``.
        """
        host = get_docker_ip()
        user = env_vars_postgres["POSTGRES_USER"]
        pw = env_vars_postgres["POSTGRES_PASSWORD"]
        port = cls.POSTGRES_HOST_PORT
        db = env_vars_postgres["POSTGRES_DB"]
        return f"postgresql://{user}:{pw}@{host}:{port}/{db}"

    @classmethod
    def setup_class(cls):
        """Start Postgres, wait for it, and create the FastAPI test client.

        Any failure is logged, the container is cleaned up via
        ``teardown_class`` and the run is aborted with ``pytest.fail``.
        """
        super().setup_class()
        try:
            env_vars_postgres = {
                "POSTGRES_USER": "user",
                "POSTGRES_PASSWORD": "example",
                "POSTGRES_DB": "openwebui",
            }
            cls.docker_client = docker.from_env()
            cls.docker_client.containers.run(
                "postgres:16.2",
                detach=True,
                environment=env_vars_postgres,
                name=cls.DOCKER_CONTAINER_NAME,
                ports={5432: ("0.0.0.0", cls.POSTGRES_HOST_PORT)},
                command="postgres -c log_statement=all",
            )
            time.sleep(0.5)

            database_url = cls._create_db_url(env_vars_postgres)
            os.environ["DATABASE_URL"] = database_url

            engine = None
            connection = None
            # Up to 10 attempts, 3 seconds apart.
            for _ in range(10):
                try:
                    # Imported inside the retry loop, after DATABASE_URL is
                    # set. NOTE(review): the name is unused, so this import
                    # is presumably wanted for its side effects — confirm
                    # before removing.
                    from open_webui.config import OPEN_WEBUI_DIR  # noqa: F401

                    engine = create_engine(database_url, pool_pre_ping=True)
                    connection = engine.connect()
                    log.info("postgres is ready!")
                    break
                except Exception as e:
                    log.warning(e)
                    if engine is not None:
                        # Do not leave a half-initialised engine behind.
                        engine.dispose()
                        engine = None
                    time.sleep(3)

            if connection is None:
                raise Exception("Could not connect to Postgres")

            try:
                cls.fast_api_client = get_fast_api_client()
            finally:
                # The probe connection and its engine are only needed for the
                # readiness check; release them even if client creation fails.
                connection.close()
                engine.dispose()
        except Exception as ex:
            log.exception(ex)
            cls.teardown_class()
            pytest.fail(f"Could not setup test environment: {ex}")

    def _check_db_connection(self):
        """Probe the application's DB session with ``SELECT 1``.

        Retries up to 10 times, 3 seconds apart, rolling the session back
        after each failure. If every attempt fails an error is logged; no
        exception is raised.
        """
        from open_webui.apps.webui.internal.db import Session

        for _ in range(10):
            try:
                Session.execute(text("SELECT 1"))
                Session.commit()
                break
            except Exception as e:
                Session.rollback()
                log.warning(e)
                time.sleep(3)
        else:
            log.error("Database connection check failed after all retries")

    def setup_method(self):
        """Run the base hook, then verify the database is reachable."""
        super().setup_method()
        self._check_db_connection()

    @classmethod
    def teardown_class(cls) -> None:
        """Force-remove the Postgres container, if there is one.

        Safe to call when ``setup_class`` failed early: a missing Docker
        client or a container that was never created is tolerated, so the
        cleanup cannot mask the original setup error.
        """
        super().teardown_class()
        client = getattr(cls, "docker_client", None)
        if client is None:
            # docker.from_env() never succeeded; nothing to clean up.
            return
        try:
            client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
        except docker.errors.NotFound:
            log.warning(
                "Container %s not found during teardown",
                cls.DOCKER_CONTAINER_NAME,
            )

    def teardown_method(self):
        """Commit pending work and truncate the application tables."""
        from open_webui.apps.webui.internal.db import Session

        # Finish any transaction left open by the test before truncating.
        Session.commit()

        tables = [
            "auth",
            "chat",
            "chatidtag",
            "document",
            "memory",
            "model",
            "prompt",
            "tag",
            # Quoted because ``user`` is a reserved word in PostgreSQL.
            '"user"',
        ]
        for table in tables:
            Session.execute(text(f"TRUNCATE TABLE {table}"))
        Session.commit()
|
|