Spaces:
Sleeping
Sleeping
""" | |
Task queue management
This module provides classes and functions for managing the task queue.
Classes:
    QueueTask: A class representing a task in the queue.
    TaskQueue: A class for managing the task queue.
    TaskOutputs: A container that records a task's outputs and forwards preview updates to the task.
""" | |
import uuid | |
import time | |
from typing import List, Tuple | |
import numpy as np | |
import requests | |
from fooocusapi.utils.file_utils import delete_output_file, get_file_serve_url | |
from fooocusapi.utils.img_utils import narray_to_base64img | |
from fooocusapi.utils.logger import logger | |
from fooocusapi.models.common.task import ImageGenerationResult, GenerationFinishReason | |
from fooocusapi.parameters import ImageGenerationParams | |
from fooocusapi.models.common.task import TaskType | |
class QueueTask:
    """
    A class representing a task in the queue.

    Attributes:
        job_id (str): The unique identifier for the task, generated by uuid.
        task_type (TaskType): The type of task.
        req_param (ImageGenerationParams): The request parameters of the task.
        is_finished (bool): Indicates whether the task has been completed.
        finish_progress (int): The progress of the task completion (capped at 100 by set_progress).
        in_queue_mills (int): The time the task was added to the queue, in milliseconds.
        start_mills (int): The time the task started, in milliseconds (0 until set).
        finish_mills (int): The time the task finished, in milliseconds (0 until set).
        finish_with_error (bool): Indicates whether the task finished with an error.
        task_status (str | None): The status text of the task.
        task_step_preview (str | None): The most recent step preview, stored as a string.
        task_result (List[ImageGenerationResult] | None): The result of the task, None until set.
        error_message (str | None): The error message, if any.
        webhook_url (str | None): The per-task webhook URL, if any.
    """

    job_id: str
    task_type: TaskType
    req_param: ImageGenerationParams
    is_finished: bool = False
    finish_progress: int = 0
    in_queue_mills: int
    start_mills: int = 0
    finish_mills: int = 0
    finish_with_error: bool = False
    task_status: str | None = None
    task_step_preview: str | None = None
    task_result: List[ImageGenerationResult] | None = None
    error_message: str | None = None
    webhook_url: str | None = None  # attribute for individual webhook_url

    def __init__(
        self,
        job_id: str,
        task_type: TaskType,
        req_param: ImageGenerationParams,
        webhook_url: str | None = None,
    ):
        """
        Create a task and stamp the time it entered the queue.

        Arguments:
            job_id {str} -- unique task id
            task_type {TaskType} -- type of the task
            req_param {ImageGenerationParams} -- request parameters
            webhook_url {str | None} -- optional per-task webhook URL
        """
        self.job_id = job_id
        self.task_type = task_type
        self.req_param = req_param
        self.in_queue_mills = int(round(time.time() * 1000))
        self.webhook_url = webhook_url

    def set_progress(self, progress: int, status: str | None):
        """
        Set progress and status

        Arguments:
            progress {int} -- progress; values above 100 are capped at 100
            status {str | None} -- status text
        """
        self.finish_progress = min(progress, 100)
        self.task_status = status

    def set_step_preview(self, task_step_preview: str | None):
        """
        Set step preview

        Arguments:
            task_step_preview {str | None} -- step preview
        """
        self.task_step_preview = task_step_preview

    def set_result(
        self,
        task_result: List[ImageGenerationResult],
        finish_with_error: bool,
        error_message: str | None = None,
    ):
        """
        Set task result

        On success the progress is forced to 100 and the status to "Finished";
        on error, progress and status are left as they were.

        Arguments:
            task_result {List[ImageGenerationResult]} -- task result
            finish_with_error {bool} -- finish with error
            error_message {str | None} -- error message
        """
        if not finish_with_error:
            self.finish_progress = 100
            self.task_status = "Finished"
        self.task_result = task_result
        self.finish_with_error = finish_with_error
        self.error_message = error_message

    def __str__(self) -> str:
        # Implicit concatenation of f-string pieces: a backslash continuation
        # inside the literal would copy the next line's indentation into the
        # output (or drop the separating space entirely).
        return (
            f"QueueTask(job_id={self.job_id}, task_type={self.task_type}, "
            f"is_finished={self.is_finished}, finished_progress={self.finish_progress}, "
            f"in_queue_mills={self.in_queue_mills}, start_mills={self.start_mills}, "
            f"finish_mills={self.finish_mills}, finish_with_error={self.finish_with_error}, "
            f"error_message={self.error_message}, task_status={self.task_status}, "
            f"task_step_preview={self.task_step_preview}, webhook_url={self.webhook_url})"
        )
class TaskQueue:
    """
    TaskQueue is a queue of tasks that are waiting to be processed.

    Finished tasks are moved from ``queue`` to ``history``. When ``history``
    grows past ``history_size`` its oldest entry is dropped and that entry's
    successful output files are deleted; a ``history_size`` of 0 disables
    the trimming.

    Attributes:
        queue (List[QueueTask]): Tasks not yet finished, in the order they were added.
        history (List[QueueTask]): Finished tasks, oldest first.
        last_job_id (str | None): job_id of the most recently added task.
        webhook_url (str | None): Default webhook called when a task finishes.
        persistent (bool): Whether finished tasks are also saved via the SQL client.
        queue_size (int): Maximum number of tasks accepted into ``queue``.
        history_size (int): Maximum number of tasks kept in ``history`` (0 = no limit).
    """

    queue: List[QueueTask]
    history: List[QueueTask]
    last_job_id: str | None = None
    webhook_url: str | None = None
    persistent: bool = False

    def __init__(
        self,
        queue_size: int,
        history_size: int,
        webhook_url: str | None = None,
        persistent: bool | None = False,
    ):
        # Created per instance: as class-level ``[]`` defaults these lists
        # would be shared by every TaskQueue ever constructed.
        self.queue = []
        self.history = []
        self.queue_size = queue_size
        self.history_size = history_size
        self.webhook_url = webhook_url
        self.persistent = False if persistent is None else persistent

    def add_task(
        self,
        task_type: TaskType,
        req_param: ImageGenerationParams,
        webhook_url: str | None = None,
    ) -> QueueTask | None:
        """
        Create and add task to queue
        :param task_type: task type
        :param req_param: request parameters; a plain dict is converted to ImageGenerationParams
        :param webhook_url: per-task webhook url, used instead of the queue default
        :returns: The created task, or None if the queue size limit is reached
        """
        if len(self.queue) >= self.queue_size:
            return None
        if isinstance(req_param, dict):
            req_param = ImageGenerationParams(**req_param)
        job_id = str(uuid.uuid4())
        task = QueueTask(
            job_id=job_id,
            task_type=task_type,
            req_param=req_param,
            webhook_url=webhook_url,
        )
        self.queue.append(task)
        self.last_job_id = job_id
        return task

    def get_task(self, job_id: str, include_history: bool = False) -> QueueTask | None:
        """
        Get task by job_id
        :param job_id: job id
        :param include_history: whether to include history tasks
        :returns: The task with the given job_id, or None if not found
        """
        for task in self.queue:
            if task.job_id == job_id:
                return task
        if include_history:
            for task in self.history:
                if task.job_id == job_id:
                    return task
        return None

    def is_task_ready_to_start(self, job_id: str) -> bool:
        """
        Check if the task is ready to start, i.e. it is at the head of the queue
        :param job_id: job id
        :returns: True if the task is ready to start, False otherwise
        """
        return bool(self.queue) and self.queue[0].job_id == job_id

    def is_task_finished(self, job_id: str) -> bool:
        """
        Check if the task is finished
        :param job_id: job id
        :returns: True if the task is finished, False otherwise (including unknown ids)
        """
        task = self.get_task(job_id, True)
        if task is None:
            return False
        return task.is_finished

    def start_task(self, job_id: str):
        """
        Start task by job_id: record its start time. Unknown ids are ignored.
        :param job_id: job id
        """
        task = self.get_task(job_id)
        if task is not None:
            task.start_mills = int(round(time.time() * 1000))

    def finish_task(self, job_id: str):
        """
        Finish task by job_id.

        Marks the task finished, calls its webhook (best effort), moves it from
        the queue to history, saves it to the database when ``persistent`` is
        set, and trims history. Unknown ids are ignored.
        :param job_id: job id
        """
        task = self.get_task(job_id)
        if task is None:
            return
        task.is_finished = True
        task.finish_mills = int(round(time.time() * 1000))

        data = self._build_webhook_payload(task)
        # Use the task's webhook_url if available, else use the default
        webhook_url = task.webhook_url or self.webhook_url
        if webhook_url:
            self._send_webhook(webhook_url, data)

        # Move task to history
        self.queue.remove(task)
        self.history.append(task)

        if self.persistent:
            self._save_history(task, data)

        self._trim_history()

    @staticmethod
    def _build_webhook_payload(task: QueueTask) -> dict:
        """Build the payload: job id plus one {url, seed} entry per task result."""
        job_result = []
        if isinstance(task.task_result, list):
            job_result = [
                {
                    "url": get_file_serve_url(item.im) if item.im else None,
                    "seed": item.seed if item.seed else "-1",
                }
                for item in task.task_result
            ]
        return {"job_id": task.job_id, "job_result": job_result}

    @staticmethod
    def _send_webhook(webhook_url: str, data: dict):
        """POST the payload to the webhook; failures are printed, never raised."""
        try:
            res = requests.post(webhook_url, json=data, timeout=15)
            print(f"Call webhook response status: {res.status_code}")
        except Exception as e:  # best effort: a failing webhook must not break finishing
            print("Call webhook error:", e)

    @staticmethod
    def _save_history(task: QueueTask, data: dict):
        """Save a finished task to the database via ``fooocusapi.sql_client.add_history``."""
        # Local import (as in the original): the SQL client is only needed
        # when persistence is enabled.
        from fooocusapi.sql_client import add_history

        # Results without an output image carry url=None; skip them so that
        # str.join does not raise TypeError.
        result_url = ",".join(job["url"] for job in data["job_result"] if job["url"])
        # task_result may still be None (its default) or empty.
        # NOTE(review): add_history then receives finish_reason=None — confirm it accepts that.
        finish_reason = (
            task.task_result[0].finish_reason.value if task.task_result else None
        )
        add_history(
            params=task.req_param.to_dict(),
            task_type=task.task_type.value,
            task_id=task.job_id,
            result_url=result_url,
            finish_reason=finish_reason,
        )

    def _trim_history(self):
        """Drop the oldest history entry once history exceeds history_size (0 = no limit)."""
        if self.history_size == 0 or len(self.history) <= self.history_size:
            return
        removed_task = self.history.pop(0)
        if isinstance(removed_task.task_result, list):
            for item in removed_task.task_result:
                # Only successful results own an output file worth deleting
                if (
                    isinstance(item, ImageGenerationResult)
                    and item.finish_reason == GenerationFinishReason.success
                    and item.im is not None
                ):
                    delete_output_file(item.im)
        logger.std_info(
            f"[TaskQueue] Clean task history, remove task: {removed_task.job_id}"
        )
class TaskOutputs:
    """
    TaskOutputs is a container for task outputs.

    Every appended output is recorded in ``outputs``. Entries of the form
    ``("preview", (progress, text[, image]))`` are also forwarded to the bound
    task as a progress update and, when ``image`` is a numpy array, as a step
    preview.
    """

    outputs: list

    def __init__(self, task: QueueTask):
        self.task = task
        # Per-instance list: a class-level ``outputs = []`` is shared by every
        # TaskOutputs, mixing outputs of different tasks and growing forever.
        self.outputs = []

    def append(self, args: list):
        """
        Append output to task outputs list
        :param args: output arguments; ``args[0]`` is the output kind and
            ``args[1]`` its payload
        """
        self.outputs.append(args)
        if len(args) < 2 or args[0] != "preview":
            return
        payload = args[1]
        if not isinstance(payload, tuple) or len(payload) < 2:
            return
        number, text = payload[0], payload[1]
        self.task.set_progress(number, text)
        if len(payload) >= 3 and isinstance(payload[2], np.ndarray):
            self.task.set_step_preview(narray_to_base64img(payload[2]))