# MTO-TCP / fooocusapi/task_queue.py
# Commit 36cd99b by ishworrsubedii: "Updated the latest changes" (file size 10.9 kB)
"""
Task queue management

This module provides the classes used to queue image-generation tasks,
track their progress, and keep a bounded history of finished tasks.

Classes:
    QueueTask: A class representing a task in the queue.
    TaskQueue: A class for managing the task queue and the finished-task history.
    TaskOutputs: A container collecting a task's outputs and forwarding
        "preview" outputs to the task as progress / step-preview updates.
"""
import uuid
import time
from typing import List, Tuple
import numpy as np
import requests
from fooocusapi.utils.file_utils import delete_output_file, get_file_serve_url
from fooocusapi.utils.img_utils import narray_to_base64img
from fooocusapi.utils.logger import logger
from fooocusapi.models.common.task import ImageGenerationResult, GenerationFinishReason
from fooocusapi.parameters import ImageGenerationParams
from fooocusapi.models.common.task import TaskType
class QueueTask:
    """
    A class representing a task in the queue.

    Attributes:
        job_id (str): The unique identifier for the task (TaskQueue generates it with uuid4).
        task_type (TaskType): The type of task.
        req_param (ImageGenerationParams): The request parameters of the task.
        is_finished (bool): Indicates whether the task has been completed.
        finish_progress (int): The progress of the task completion, clamped to 0-100.
        in_queue_mills (int): The time the task was added to the queue, in milliseconds.
        start_mills (int): The time the task started, in milliseconds (0 until started).
        finish_mills (int): The time the task finished, in milliseconds (0 until finished).
        finish_with_error (bool): Indicates whether the task finished with an error.
        task_status (str | None): The status text of the task.
        task_step_preview (str | None): The latest step preview, stored as a string.
        task_result (List[ImageGenerationResult] | None): The result of the task, if set.
        error_message (str | None): The error message, if any.
        webhook_url (str | None): The task's own webhook URL, if any.
    """

    job_id: str
    task_type: TaskType
    req_param: ImageGenerationParams
    is_finished: bool = False
    finish_progress: int = 0
    in_queue_mills: int
    start_mills: int = 0
    finish_mills: int = 0
    finish_with_error: bool = False
    task_status: str | None = None
    task_step_preview: str | None = None
    task_result: List[ImageGenerationResult] | None = None
    error_message: str | None = None
    webhook_url: str | None = None  # attribute for individual webhook_url

    def __init__(
        self,
        job_id: str,
        task_type: TaskType,
        req_param: ImageGenerationParams,
        webhook_url: str | None = None,
    ):
        self.job_id = job_id
        self.task_type = task_type
        self.req_param = req_param
        # Enqueue timestamp in epoch milliseconds.
        self.in_queue_mills = int(round(time.time() * 1000))
        self.webhook_url = webhook_url

    def set_progress(self, progress: int, status: str | None):
        """
        Set progress and status

        Arguments:
            progress {int} -- progress percentage; clamped into [0, 100]
            status {str | None} -- status text
        """
        self.finish_progress = max(0, min(progress, 100))
        self.task_status = status

    def set_step_preview(self, task_step_preview: str | None):
        """
        Set step preview

        Arguments:
            task_step_preview {str | None} -- step preview
        """
        self.task_step_preview = task_step_preview

    def set_result(
        self,
        task_result: List[ImageGenerationResult],
        finish_with_error: bool,
        error_message: str | None = None,
    ):
        """
        Set task result

        On success, progress is forced to 100 and the status to "Finished";
        on error, the current progress and status are left untouched.

        Arguments:
            task_result {List[ImageGenerationResult]} -- task result
            finish_with_error {bool} -- finish with error
            error_message {str | None} -- error message
        """
        if not finish_with_error:
            self.finish_progress = 100
            self.task_status = "Finished"
        self.task_result = task_result
        self.finish_with_error = finish_with_error
        self.error_message = error_message

    def __str__(self) -> str:
        # Implicit concatenation keeps every field separated by exactly ", "
        # (the previous backslash continuations dropped one separator and
        # leaked source indentation into the text).
        return (
            f"QueueTask(job_id={self.job_id}, task_type={self.task_type}, "
            f"is_finished={self.is_finished}, finished_progress={self.finish_progress}, "
            f"in_queue_mills={self.in_queue_mills}, start_mills={self.start_mills}, "
            f"finish_mills={self.finish_mills}, finish_with_error={self.finish_with_error}, "
            f"error_message={self.error_message}, task_status={self.task_status}, "
            f"task_step_preview={self.task_step_preview}, webhook_url={self.webhook_url})"
        )
class TaskQueue:
    """
    TaskQueue is a queue of tasks that are waiting to be processed.

    Finished tasks are moved from `queue` to `history`. When `history` grows past
    `history_size`, the oldest entries are dropped and the output files of their
    successful results are deleted.

    Attributes:
        queue: List[QueueTask] -- tasks not finished yet, in submission order
        history: List[QueueTask] -- finished tasks, oldest first
        last_job_id: str | None -- job_id of the most recently added task
        webhook_url: str | None -- default webhook URL, used when a task has none
        persistent: bool -- when True, finished tasks are also recorded via
            fooocusapi.sql_client.add_history
        queue_size: int -- maximum number of tasks accepted into `queue`
        history_size: int -- maximum number of tasks kept in `history`; 0 disables trimming
    """

    queue: List[QueueTask]
    history: List[QueueTask]
    last_job_id: str | None = None
    webhook_url: str | None = None
    persistent: bool = False

    def __init__(
        self,
        queue_size: int,
        history_size: int,
        webhook_url: str | None = None,
        persistent: bool | None = False,
    ):
        # Create the lists per instance: mutable class-level defaults would be
        # shared by every TaskQueue object.
        self.queue = []
        self.history = []
        self.last_job_id = None
        self.queue_size = queue_size
        self.history_size = history_size
        self.webhook_url = webhook_url
        self.persistent = False if persistent is None else persistent

    def add_task(
        self,
        task_type: TaskType,
        req_param: ImageGenerationParams,
        webhook_url: str | None = None,
    ) -> QueueTask | None:
        """
        Create and add task to queue

        :param task_type: task type
        :param req_param: request parameters; a dict is converted to ImageGenerationParams
        :param webhook_url: per-task webhook url, overriding the queue default
        :returns: The created task, or None if the queue size limit is reached
        """
        if len(self.queue) >= self.queue_size:
            return None
        if isinstance(req_param, dict):
            req_param = ImageGenerationParams(**req_param)
        task = QueueTask(
            job_id=str(uuid.uuid4()),
            task_type=task_type,
            req_param=req_param,
            webhook_url=webhook_url,
        )
        self.queue.append(task)
        self.last_job_id = task.job_id
        return task

    def get_task(self, job_id: str, include_history: bool = False) -> QueueTask | None:
        """
        Get task by job_id

        :param job_id: job id
        :param include_history: whether to include history tasks
        :returns: The task with the given job_id, or None if not found
        """
        for task in self.queue:
            if task.job_id == job_id:
                return task
        if include_history:
            for task in self.history:
                if task.job_id == job_id:
                    return task
        return None

    def is_task_ready_to_start(self, job_id: str) -> bool:
        """
        Check if the task is ready to start, i.e. it is at the head of the queue

        :param job_id: job id
        :returns: True if the task is ready to start, False otherwise
        """
        return bool(self.queue) and self.queue[0].job_id == job_id

    def is_task_finished(self, job_id: str) -> bool:
        """
        Check if the task is finished

        :param job_id: job id
        :returns: True if the task is finished, False otherwise (including unknown ids)
        """
        task = self.get_task(job_id, True)
        if task is None:
            return False
        return task.is_finished

    def start_task(self, job_id: str):
        """
        Start task by job_id: record its start time in epoch milliseconds

        :param job_id: job id
        """
        task = self.get_task(job_id)
        if task is not None:
            task.start_mills = int(round(time.time() * 1000))

    def finish_task(self, job_id: str):
        """
        Finish task by job_id

        Marks the task finished, notifies its webhook (best effort), moves it from
        the queue to the history, optionally persists it, and trims the history.
        Unknown job ids are ignored.

        :param job_id: job id
        """
        task = self.get_task(job_id)
        if task is None:
            return
        task.is_finished = True
        task.finish_mills = int(round(time.time() * 1000))

        data = self._build_webhook_payload(task)
        # Use the task's webhook_url if available, else use the default
        webhook_url = task.webhook_url or self.webhook_url
        if webhook_url:
            self._send_webhook(webhook_url, data)

        # Move task to history
        self.queue.remove(task)
        self.history.append(task)

        if self.persistent:
            self._persist_task(task, data)

        self._trim_history()

    @staticmethod
    def _build_webhook_payload(task: QueueTask) -> dict:
        """Build the {"job_id", "job_result"} payload describing a finished task."""
        job_result = []
        if isinstance(task.task_result, list):
            for item in task.task_result:
                job_result.append(
                    {
                        "url": get_file_serve_url(item.im) if item.im else None,
                        "seed": item.seed if item.seed else "-1",
                    }
                )
        return {"job_id": task.job_id, "job_result": job_result}

    @staticmethod
    def _send_webhook(webhook_url: str, data: dict):
        """POST the payload to the webhook; failures are logged, never raised."""
        try:
            res = requests.post(webhook_url, json=data, timeout=15)
            logger.std_info(f"[TaskQueue] Call webhook response status: {res.status_code}")
        except Exception as e:  # deliberate best effort: a broken hook must not stall the queue
            logger.std_info(f"[TaskQueue] Call webhook error: {e}")

    @staticmethod
    def _persist_task(task: QueueTask, data: dict):
        """Record a finished task in the history database via add_history."""
        results = task.task_result if isinstance(task.task_result, list) else []
        if not results:
            # add_history needs a finish reason, which only a result provides.
            logger.std_info(
                f"[TaskQueue] Skip saving history for task without result: {task.job_id}"
            )
            return
        # Imported here rather than at module level, so it is only loaded when
        # persistence is enabled.
        from fooocusapi.sql_client import add_history

        add_history(
            params=task.req_param.to_dict(),
            task_type=task.task_type.value,
            task_id=task.job_id,
            # Results without an image have url None; keep one (possibly empty)
            # entry per result instead of crashing str.join.
            result_url=",".join(job["url"] or "" for job in data["job_result"]),
            finish_reason=results[0].finish_reason.value,
        )

    def _trim_history(self):
        """Drop the oldest history entries beyond history_size (0 = unlimited)."""
        while self.history_size != 0 and len(self.history) > self.history_size:
            removed_task = self.history.pop(0)
            if isinstance(removed_task.task_result, list):
                for item in removed_task.task_result:
                    if (
                        isinstance(item, ImageGenerationResult)
                        and item.finish_reason == GenerationFinishReason.success
                        and item.im is not None
                    ):
                        delete_output_file(item.im)
            logger.std_info(
                f"[TaskQueue] Clean task history, remove task: {removed_task.job_id}"
            )
class TaskOutputs:
    """
    TaskOutputs is a container for task outputs.

    Every appended output is recorded in `outputs`; "preview" outputs are also
    forwarded to the owning task as progress / step-preview updates.
    """

    outputs: List[List[object]]

    def __init__(self, task: "QueueTask"):
        self.task = task
        # Per-instance list: a class-level `[]` would be shared by, and keep
        # growing across, every TaskOutputs ever created.
        self.outputs = []

    def append(self, args: List[object]):
        """
        Append output to task outputs list

        An output of the form ["preview", (number, text, ...)] also updates the
        task: number/text go to task.set_progress, and if a third element is a
        numpy array it is converted with narray_to_base64img and passed to
        task.set_step_preview.

        :param args: output arguments
        """
        self.outputs.append(args)
        if len(args) < 2 or args[0] != "preview":
            return
        payload = args[1]
        if not isinstance(payload, tuple) or len(payload) < 2:
            return
        self.task.set_progress(payload[0], payload[1])
        if len(payload) >= 3 and isinstance(payload[2], np.ndarray):
            self.task.set_step_preview(narray_to_base64img(payload[2]))