import os
import sys
import asyncio
import json
import logging
from pathlib import Path
from typing import Optional, Tuple, Dict
from contextlib import asynccontextmanager

import aiosqlite
import aiofiles
import xxhash

from .config import NEED_MIGRATION
from .models import CrawlResult, MarkdownGenerationResult
from .utils import ensure_content_dirs, generate_content_hash, get_error_context, create_box_message
from .version_manager import VersionManager
from .async_logger import AsyncLogger

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

base_directory = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
os.makedirs(base_directory, exist_ok=True)
DB_PATH = os.path.join(base_directory, "crawl4ai.db")

|
class AsyncDatabaseManager:
    def __init__(self, pool_size: int = 10, max_retries: int = 3):
        self.db_path = DB_PATH
        self.content_paths = ensure_content_dirs(os.path.dirname(DB_PATH))
        self.pool_size = pool_size
        self.max_retries = max_retries
        # One connection per asyncio task, bounded by a semaphore of pool_size
        self.connection_pool: Dict[int, aiosqlite.Connection] = {}
        self.pool_lock = asyncio.Lock()
        self.init_lock = asyncio.Lock()
        self.connection_semaphore = asyncio.Semaphore(pool_size)
        self._initialized = False
        self.version_manager = VersionManager()
        self.logger = AsyncLogger(
            log_file=os.path.join(base_directory, "crawler_db.log"),
            verbose=False,
            tag_width=10
        )
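
    # Lifecycle note: initialize() also runs lazily on the first get_connection()
    # call (guarded by init_lock), so callers may either await initialize()
    # explicitly at startup or simply start using the manager; cleanup() should be
    # awaited on shutdown to close any pooled connections.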
|
    async def initialize(self):
        """Initialize the database and connection pool"""
        try:
            self.logger.info("Initializing database", tag="INIT")

            os.makedirs(os.path.dirname(self.db_path), exist_ok=True)

            # Check whether the installed version requires schema updates/migrations
            needs_update = self.version_manager.needs_update()

            # Always make sure the base schema exists first
            await self.ainit_db()

            # Verify the table was actually created
            async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
                async with db.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' AND name='crawled_data'"
                ) as cursor:
                    result = await cursor.fetchone()
                    if not result:
                        raise Exception("crawled_data table was not created")

            if needs_update:
                self.logger.info("New version detected, running updates", tag="INIT")
                await self.update_db_schema()
                from .migrations import run_migration
                await run_migration()
                self.version_manager.update_version()
                self.logger.success("Version update completed successfully", tag="COMPLETE")
            else:
                self.logger.success("Database initialization completed successfully", tag="COMPLETE")

        except Exception as e:
            self.logger.error(
                message="Database initialization error: {error}",
                tag="ERROR",
                params={"error": str(e)}
            )
            self.logger.info(
                message="Database will be initialized on first use",
                tag="INIT"
            )
            raise
|
    async def cleanup(self):
        """Cleanup connections when shutting down"""
        async with self.pool_lock:
            for conn in self.connection_pool.values():
                await conn.close()
            self.connection_pool.clear()
|
    @asynccontextmanager
    async def get_connection(self):
        """Connection pool manager with enhanced error handling"""
        if not self._initialized:
            async with self.init_lock:
                if not self._initialized:
                    try:
                        await self.initialize()
                        self._initialized = True
                    except Exception as e:
                        error_context = get_error_context(sys.exc_info())
                        self.logger.error(
                            message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}",
                            tag="ERROR",
                            force_verbose=True,
                            params={
                                "error": str(e),
                                "context": error_context["code_context"],
                                "traceback": error_context["full_traceback"]
                            }
                        )
                        raise

        await self.connection_semaphore.acquire()
        task_id = id(asyncio.current_task())

        try:
            async with self.pool_lock:
                if task_id not in self.connection_pool:
                    try:
                        conn = await aiosqlite.connect(
                            self.db_path,
                            timeout=30.0
                        )
                        await conn.execute('PRAGMA journal_mode = WAL')
                        await conn.execute('PRAGMA busy_timeout = 5000')

                        # Verify the schema before handing the connection out
                        async with conn.execute("PRAGMA table_info(crawled_data)") as cursor:
                            columns = await cursor.fetchall()
                            column_names = [col[1] for col in columns]
                            expected_columns = {
                                'url', 'html', 'cleaned_html', 'markdown', 'extracted_content',
                                'success', 'media', 'links', 'metadata', 'screenshot',
                                'response_headers', 'downloaded_files'
                            }
                            missing_columns = expected_columns - set(column_names)
                            if missing_columns:
                                raise ValueError(f"Database missing columns: {missing_columns}")

                        self.connection_pool[task_id] = conn
                    except Exception as e:
                        error_context = get_error_context(sys.exc_info())
                        error_message = (
                            f"Unexpected error in db get_connection at line {error_context['line_no']} "
                            f"in {error_context['function']} ({error_context['filename']}):\n"
                            f"Error: {str(e)}\n\n"
                            f"Code context:\n{error_context['code_context']}"
                        )
                        self.logger.error(
                            message=create_box_message(error_message, type="error"),
                        )
                        raise

            yield self.connection_pool[task_id]

        except Exception as e:
            error_context = get_error_context(sys.exc_info())
            error_message = (
                f"Unexpected error in db get_connection at line {error_context['line_no']} "
                f"in {error_context['function']} ({error_context['filename']}):\n"
                f"Error: {str(e)}\n\n"
                f"Code context:\n{error_context['code_context']}"
            )
            self.logger.error(
                message=create_box_message(error_message, type="error"),
            )
            raise
        finally:
            async with self.pool_lock:
                if task_id in self.connection_pool:
                    await self.connection_pool[task_id].close()
                    del self.connection_pool[task_id]
            self.connection_semaphore.release()
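
    # Illustrative direct use of the pooled connection (a sketch, not a public
    # API guarantee; most callers should go through execute_with_retry instead):
    #
    #   async with async_db_manager.get_connection() as db:
    #       async with db.execute("SELECT COUNT(*) FROM crawled_data") as cursor:
    #           total = (await cursor.fetchone())[0]
    #
    # The connection is closed, removed from the pool, and its semaphore slot
    # released when the block exits.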
|
|
    async def execute_with_retry(self, operation, *args):
        """Execute database operations with retry logic"""
        for attempt in range(self.max_retries):
            try:
                async with self.get_connection() as db:
                    result = await operation(db, *args)
                    await db.commit()
                    return result
            except Exception as e:
                if attempt == self.max_retries - 1:
                    self.logger.error(
                        message="Operation failed after {retries} attempts: {error}",
                        tag="ERROR",
                        force_verbose=True,
                        params={
                            "retries": self.max_retries,
                            "error": str(e)
                        }
                    )
                    raise
                # Linear backoff: wait 1s, 2s, 3s, ... between attempts
                await asyncio.sleep(1 * (attempt + 1))
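
    # A minimal usage sketch: any coroutine taking the live connection as its first
    # argument can be passed to execute_with_retry; the helper name _touch_success
    # below is purely illustrative.
    #
    #   async def _touch_success(db, url):
    #       await db.execute(
    #           "UPDATE crawled_data SET success = ? WHERE url = ?", (True, url)
    #       )
    #
    #   await async_db_manager.execute_with_retry(_touch_success, "https://example.com")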
|
|
    async def ainit_db(self):
        """Initialize database schema"""
        async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
            await db.execute('''
                CREATE TABLE IF NOT EXISTS crawled_data (
                    url TEXT PRIMARY KEY,
                    html TEXT,
                    cleaned_html TEXT,
                    markdown TEXT,
                    extracted_content TEXT,
                    success BOOLEAN,
                    media TEXT DEFAULT '{}',
                    links TEXT DEFAULT '{}',
                    metadata TEXT DEFAULT '{}',
                    screenshot TEXT DEFAULT '',
                    response_headers TEXT DEFAULT '{}',
                    downloaded_files TEXT DEFAULT '{}'
                )
            ''')
            await db.commit()
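
    # To inspect the cache out-of-band (illustrative; any SQLite client works,
    # and the path below is the default location unless CRAWL4_AI_BASE_DIRECTORY
    # is set):
    #
    #   sqlite3 ~/.crawl4ai/crawl4ai.db "SELECT url, success FROM crawled_data LIMIT 5;"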
|
    async def update_db_schema(self):
        """Update database schema if needed"""
        async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
            cursor = await db.execute("PRAGMA table_info(crawled_data)")
            columns = await cursor.fetchall()
            column_names = [column[1] for column in columns]

            new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files']
            for column in new_columns:
                if column not in column_names:
                    await self.aalter_db_add_column(column, db)
            await db.commit()
|
    async def aalter_db_add_column(self, new_column: str, db):
        """Add new column to the database"""
        if new_column == 'response_headers':
            await db.execute(f"ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT '{{}}'")
        else:
            await db.execute(f"ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ''")
        self.logger.info(
            message="Added column '{column}' to the database",
            tag="INIT",
            params={"column": new_column}
        )
|
    async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
        """Retrieve cached URL data as CrawlResult"""
        async def _get(db):
            async with db.execute(
                'SELECT * FROM crawled_data WHERE url = ?', (url,)
            ) as cursor:
                row = await cursor.fetchone()
                if not row:
                    return None

                columns = [description[0] for description in cursor.description]
                row_dict = dict(zip(columns, row))

                # The row stores content hashes; resolve each hash to the file
                # saved under the matching content directory.
                content_fields = {
                    'html': (row_dict['html'], 'html'),
                    'cleaned_html': (row_dict['cleaned_html'], 'cleaned'),
                    'markdown': (row_dict['markdown'], 'markdown'),
                    'extracted_content': (row_dict['extracted_content'], 'extracted'),
                    'screenshot': (row_dict['screenshot'], 'screenshots'),
                }
                for field, (hash_value, content_type) in content_fields.items():
                    if hash_value:
                        content = await self._load_content(hash_value, content_type)
                        row_dict[field] = content or ""
                    else:
                        row_dict[field] = ""

                # JSON fields are stored inline as serialized strings
                json_fields = ['media', 'links', 'metadata', 'response_headers', 'markdown']
                for field in json_fields:
                    try:
                        row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {}
                    except json.JSONDecodeError:
                        row_dict[field] = {}

                # A dict here means a serialized MarkdownGenerationResult was cached
                if isinstance(row_dict['markdown'], dict):
                    row_dict['markdown_v2'] = row_dict['markdown']
                    if row_dict['markdown'].get('raw_markdown'):
                        row_dict['markdown'] = row_dict['markdown']['raw_markdown']

                try:
                    row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else []
                except json.JSONDecodeError:
                    row_dict['downloaded_files'] = []

                # Keep only fields that CrawlResult actually declares
                valid_fields = CrawlResult.__annotations__.keys()
                filtered_dict = {k: v for k, v in row_dict.items() if k in valid_fields}
                return CrawlResult(**filtered_dict)

        try:
            return await self.execute_with_retry(_get)
        except Exception as e:
            self.logger.error(
                message="Error retrieving cached URL: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )
            return None
|
    async def acache_url(self, result: CrawlResult):
        """Cache CrawlResult data"""
        # Map each field to (content, content_type) for content-addressed storage
        content_map = {
            'html': (result.html, 'html'),
            'cleaned_html': (result.cleaned_html or "", 'cleaned'),
            'markdown': None,  # filled in below once the markdown form is known
            'extracted_content': (result.extracted_content or "", 'extracted'),
            'screenshot': (result.screenshot or "", 'screenshots')
        }

        try:
            # Normalize markdown to a serialized MarkdownGenerationResult
            if isinstance(result.markdown, MarkdownGenerationResult):
                content_map['markdown'] = (result.markdown.model_dump_json(), 'markdown')
            elif getattr(result, 'markdown_v2', None) is not None:
                content_map['markdown'] = (result.markdown_v2.model_dump_json(), 'markdown')
            elif isinstance(result.markdown, str):
                markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown)
                content_map['markdown'] = (markdown_result.model_dump_json(), 'markdown')
            else:
                content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')
        except Exception as e:
            self.logger.warning(
                message=f"Error processing markdown content: {str(e)}",
                tag="WARNING"
            )
            # Fall back to an empty markdown result on error
            content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')

        content_hashes = {}
        for field, (content, content_type) in content_map.items():
            content_hashes[field] = await self._store_content(content, content_type)

        async def _cache(db):
            await db.execute('''
                INSERT INTO crawled_data (
                    url, html, cleaned_html, markdown,
                    extracted_content, success, media, links, metadata,
                    screenshot, response_headers, downloaded_files
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(url) DO UPDATE SET
                    html = excluded.html,
                    cleaned_html = excluded.cleaned_html,
                    markdown = excluded.markdown,
                    extracted_content = excluded.extracted_content,
                    success = excluded.success,
                    media = excluded.media,
                    links = excluded.links,
                    metadata = excluded.metadata,
                    screenshot = excluded.screenshot,
                    response_headers = excluded.response_headers,
                    downloaded_files = excluded.downloaded_files
            ''', (
                result.url,
                content_hashes['html'],
                content_hashes['cleaned_html'],
                content_hashes['markdown'],
                content_hashes['extracted_content'],
                result.success,
                json.dumps(result.media),
                json.dumps(result.links),
                json.dumps(result.metadata or {}),
                content_hashes['screenshot'],
                json.dumps(result.response_headers or {}),
                json.dumps(result.downloaded_files or [])
            ))

        try:
            await self.execute_with_retry(_cache)
        except Exception as e:
            self.logger.error(
                message="Error caching URL: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )
|
    async def aget_total_count(self) -> int:
        """Get total number of cached URLs"""
        async def _count(db):
            async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor:
                result = await cursor.fetchone()
                return result[0] if result else 0

        try:
            return await self.execute_with_retry(_count)
        except Exception as e:
            self.logger.error(
                message="Error getting total count: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )
            return 0
|
    async def aclear_db(self):
        """Clear all data from the database"""
        async def _clear(db):
            await db.execute('DELETE FROM crawled_data')

        try:
            await self.execute_with_retry(_clear)
        except Exception as e:
            self.logger.error(
                message="Error clearing database: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )
|
    async def aflush_db(self):
        """Drop the entire table"""
        async def _flush(db):
            await db.execute('DROP TABLE IF EXISTS crawled_data')

        try:
            await self.execute_with_retry(_flush)
        except Exception as e:
            self.logger.error(
                message="Error flushing database: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )
|
    async def _store_content(self, content: str, content_type: str) -> str:
        """Store content in filesystem and return hash"""
        if not content:
            return ""

        content_hash = generate_content_hash(content)
        file_path = os.path.join(self.content_paths[content_type], content_hash)

        # Content-addressed storage: identical content is written only once
        if not os.path.exists(file_path):
            async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
                await f.write(content)

        return content_hash
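
    # Round-trip sketch: _store_content returns the hash persisted in the
    # crawled_data row, and _load_content resolves that same hash back to text.
    #
    #   h = await async_db_manager._store_content("<html></html>", "html")
    #   text = await async_db_manager._load_content(h, "html")
    #   assert text == "<html></html>"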
|
|
    async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]:
        """Load content from filesystem by hash"""
        if not content_hash:
            return None

        file_path = os.path.join(self.content_paths[content_type], content_hash)
        try:
            async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
                return await f.read()
        except Exception:
            self.logger.error(
                message="Failed to load content: {file_path}",
                tag="ERROR",
                force_verbose=True,
                params={"file_path": file_path}
            )
            return None
|
async_db_manager = AsyncDatabaseManager()
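

# A minimal end-to-end sketch, not part of the library API. It assumes CrawlResult
# accepts url/html/success/markdown keyword arguments, which may vary across versions.
if __name__ == "__main__":
    async def _demo():
        await async_db_manager.initialize()
        result = CrawlResult(
            url="https://example.com",
            html="<html><body>Example</body></html>",
            success=True,
            markdown="# Example",
        )
        await async_db_manager.acache_url(result)
        cached = await async_db_manager.aget_cached_url("https://example.com")
        print("cached:", cached is not None, "total:", await async_db_manager.aget_total_count())
        await async_db_manager.cleanup()

    asyncio.run(_demo())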
|