import logging |
import os |
import time |
import docker |
import pytest |
from docker import DockerClient |
from pytest_docker.plugin import get_docker_ip |
from fastapi.testclient import TestClient |
from sqlalchemy import text, create_engine |
log = logging.getLogger(__name__) |
def get_fast_api_client():
    """Build a ``TestClient`` for the application exported by ``main``.

    ``main`` is imported lazily, at call time rather than at module import.
    The client is entered as a context manager and handed back to the caller
    only after that context has been left again.

    Returns:
        The object produced by entering ``TestClient(app)``.
    """
    from main import app

    with TestClient(app) as entered_client:
        client = entered_client
    return client
class AbstractIntegrationTest:
    """Base class for integration tests that address one URL prefix.

    Subclasses set ``BASE_PATH`` to the prefix shared by their endpoints and
    use :meth:`create_url` to build request URLs. The xunit-style setup and
    teardown hooks are no-ops here so subclasses can always call ``super()``.
    """

    # URL prefix shared by the endpoints under test; subclasses must set it.
    BASE_PATH = None

    def create_url(self, path="", query_params=None):
        """Join ``BASE_PATH``, ``path`` and an optional query string.

        Both ``BASE_PATH`` and ``path`` are split on ``/``; empty segments are
        dropped and surrounding whitespace is stripped, so leading, trailing
        and doubled slashes do not matter. The result has no leading slash.

        Args:
            path: Path relative to ``BASE_PATH``. ``None`` is treated as "".
            query_params: Optional mapping of query parameters. Keys and
                values are percent-encoded; a list or tuple value is emitted
                as one ``key=value`` pair per element.

        Returns:
            The joined URL as a string, e.g. ``"api/v1/chats/c1?id=my-model"``.

        Raises:
            Exception: If ``BASE_PATH`` has not been set.
        """
        from urllib.parse import urlencode

        if self.BASE_PATH is None:
            raise Exception("BASE_PATH is not set")

        raw_segments = (*self.BASE_PATH.split("/"), *(path or "").split("/"))
        segments = [segment.strip() for segment in raw_segments if segment.strip()]
        url = "/".join(segments)

        if query_params:
            # Encode instead of interpolating raw text, so that characters
            # such as "&", "=" or spaces inside a value cannot break the URL.
            url = f"{url}?{urlencode(query_params, doseq=True)}"
        return url

    @classmethod
    def setup_class(cls):
        """Class-level setup hook; intentionally empty in the base class."""
        pass

    def setup_method(self):
        """Per-test setup hook; intentionally empty in the base class."""
        pass

    @classmethod
    def teardown_class(cls):
        """Class-level teardown hook; intentionally empty in the base class."""
        pass

    def teardown_method(self):
        """Per-test teardown hook; intentionally empty in the base class."""
        pass
class AbstractPostgresTest(AbstractIntegrationTest):
    """Integration-test base class backed by a throwaway Postgres container.

    ``setup_class`` starts a ``postgres:16.2`` container through the Docker
    SDK, exports its connection string as ``DATABASE_URL``, waits until a
    connection can be opened and then builds the FastAPI test client.
    ``teardown_method`` truncates the application tables after every test and
    ``teardown_class`` force-removes the container.
    """

    # Name given to the container so teardown_class can find and remove it.
    DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
    # Host port that the container's 5432 is published on.
    POSTGRES_HOST_PORT = 8081
    docker_client: DockerClient

    @classmethod
    def _create_db_url(cls, env_vars_postgres: dict) -> str:
        """Build the ``postgresql://`` URL for the test container.

        Args:
            env_vars_postgres: Mapping holding ``POSTGRES_USER``,
                ``POSTGRES_PASSWORD`` and ``POSTGRES_DB``.

        Returns:
            A URL pointing at ``get_docker_ip()`` on ``POSTGRES_HOST_PORT``.
        """
        host = get_docker_ip()
        user = env_vars_postgres["POSTGRES_USER"]
        pw = env_vars_postgres["POSTGRES_PASSWORD"]
        db = env_vars_postgres["POSTGRES_DB"]
        return f"postgresql://{user}:{pw}@{host}:{cls.POSTGRES_HOST_PORT}/{db}"

    @classmethod
    def _wait_for_postgres(cls, database_url: str, retries: int = 10, delay: float = 3) -> bool:
        """Return True once a connection opens, False after ``retries`` failures."""
        for _ in range(retries):
            engine = None
            try:
                # NOTE(review): the imported name is unused; presumably the
                # import is kept for its side effects once DATABASE_URL is
                # set -- confirm before removing.
                from open_webui.config import OPEN_WEBUI_DIR  # noqa: F401

                engine = create_engine(database_url, pool_pre_ping=True)
                with engine.connect():
                    pass
                log.info("postgres is ready!")
                return True
            except Exception as e:
                log.warning(e)
                time.sleep(delay)
            finally:
                # Release the probe engine's pool so failed or successful
                # attempts do not leave connections behind.
                if engine is not None:
                    engine.dispose()
        return False

    @classmethod
    def setup_class(cls):
        """Start Postgres in Docker and create ``cls.fast_api_client``.

        Any failure is logged, the container is removed on a best-effort
        basis, and the test class is failed through ``pytest.fail``.
        """
        super().setup_class()
        try:
            env_vars_postgres = {
                "POSTGRES_USER": "user",
                "POSTGRES_PASSWORD": "example",
                "POSTGRES_DB": "openwebui",
            }
            cls.docker_client = docker.from_env()
            cls.docker_client.containers.run(
                "postgres:16.2",
                detach=True,
                environment=env_vars_postgres,
                # The name is required: teardown_class looks the container up
                # by DOCKER_CONTAINER_NAME.
                name=cls.DOCKER_CONTAINER_NAME,
                ports={5432: ("", cls.POSTGRES_HOST_PORT)},
                command="postgres -c log_statement=all",
            )
            time.sleep(0.5)

            database_url = cls._create_db_url(env_vars_postgres)
            os.environ["DATABASE_URL"] = database_url

            if not cls._wait_for_postgres(database_url):
                raise Exception("Could not connect to Postgres")
            cls.fast_api_client = get_fast_api_client()
        except Exception as ex:
            log.error(ex)
            try:
                cls.teardown_class()
            except Exception as cleanup_error:
                # A failed cleanup must not hide the original setup error.
                log.warning(cleanup_error)
            pytest.fail(f"Could not setup test environment: {ex}")

    def _check_db_connection(self):
        """Run ``SELECT 1`` through the app session, retrying up to 10 times.

        Each failure rolls the session back, logs a warning and sleeps for
        three seconds. If every attempt fails the method returns silently.
        """
        from open_webui.apps.webui.internal.db import Session

        for _ in range(10):
            try:
                Session.execute(text("SELECT 1"))
                Session.commit()
                break
            except Exception as e:
                Session.rollback()
                log.warning(e)
                time.sleep(3)

    def setup_method(self):
        """Ensure the database session is usable before each test."""
        super().setup_method()
        self._check_db_connection()

    @classmethod
    def teardown_class(cls) -> None:
        """Force-remove the Postgres container started by ``setup_class``."""
        super().teardown_class()
        client = getattr(cls, "docker_client", None)
        if client is None:
            # setup_class failed before a Docker client was created.
            return
        client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)

    def teardown_method(self):
        """Truncate the application tables so tests do not share data."""
        from open_webui.apps.webui.internal.db import Session

        Session.commit()
        # "user" is written with double quotes in the SQL itself.
        tables = [
            "auth",
            "chat",
            "chatidtag",
            "document",
            "memory",
            "model",
            "prompt",
            "tag",
            '"user"',
        ]
        for table in tables:
            Session.execute(text(f"TRUNCATE TABLE {table}"))
        Session.commit()