File size: 4,101 Bytes
19b3da3 7fbdac4 19b3da3 10230ea 19b3da3 f1235a4 19b3da3 10230ea 19b3da3 f1235a4 19b3da3 10230ea 19b3da3 f1235a4 19b3da3 10230ea 19b3da3 42ef134 19b3da3 f1235a4 19b3da3 10230ea 19b3da3 10230ea 19b3da3 7fbdac4 19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import functools
import traceback
from typing import Dict, List, Optional

import requests
from pydash import includes
from requests.adapters import HTTPAdapter, Retry

from internals.data.task import Task
from internals.util.config import api_endpoint, api_headers
from internals.util.slack import Slack
class RetryRequest:
    """Factory for a ``requests.Session`` that retries failed requests.

    Calling ``RetryRequest()`` returns a configured ``requests.Session``, not a
    ``RetryRequest`` instance. ``__new__`` returns the session directly, so
    ``__init__`` never runs. Callers use it as a context manager:
    ``with RetryRequest() as session: ...``.

    Retry policy: at most 5 retries in total, with exponential backoff
    (``backoff_factor=2``), on connection errors and on HTTP 500/502/503/504.
    """

    def __new__(cls):
        retry = Retry(
            total=5,
            backoff_factor=2,
            status_forcelist=[500, 502, 503, 504],
            # urllib3's default allowed_methods omits PATCH, which is the verb
            # this module's update calls use; without adding it, the
            # status_forcelist above never triggers a retry for them. The PATCH
            # bodies here set fields to given values, so repeating them is
            # presumably idempotent server-side -- NOTE(review): confirm.
            # Requires urllib3 >= 1.26 (allowed_methods / DEFAULT_ALLOWED_METHODS).
            allowed_methods=Retry.DEFAULT_ALLOWED_METHODS | {"PATCH"},
        )
        adapter = HTTPAdapter(max_retries=retry)
        session = requests.Session()
        # Apply the retry policy to both schemes, not just https.
        session.mount("https://", adapter)
        session.mount("http://", adapter)
        return session
def updateSource(sourceId, userId, state):
    """Set the ``state`` of source ``sourceId`` through the autodraft API.

    Best-effort: any ``requests`` failure (timeouts, connection errors,
    exhausted retries, and non-2xx responses) is printed and swallowed, so
    callers are never interrupted by a failed status update.

    Args:
        sourceId: Source identifier, interpolated into the request URL.
        userId: Sent as the ``user-id`` header (converted with ``str``).
        state: New state value, e.g. ``"INPROGRESS"``, ``"COMPLETED"`` or
            ``"FAILED"``.

    Returns:
        None.
    """
    url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": state}
    try:
        with RetryRequest() as session:
            response = session.patch(url, headers=headers, json=data, timeout=10)
            # Without this, a 4xx/5xx reply was silently treated as success.
            # HTTPError is a RequestException, so it is reported below.
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while updating source")
    except requests.exceptions.RequestException as e:
        print(f"Error while updating source: {e}")
def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
    """Mark source ``sourceId`` as ``ACTIVE`` and record its NSFW flag.

    Sends a PATCH to ``/autodraft-crecoai/source/<sourceId>/generatedImages``
    with body ``{"state": "ACTIVE", "has_nsfw": has_nsfw}``.

    Best-effort: any ``requests`` failure (timeouts, connection errors,
    exhausted retries, and non-2xx responses) is printed and swallowed.

    Args:
        sourceId: Source identifier, interpolated into the request URL.
        userId: Sent as the ``user-id`` header (converted with ``str``).
        has_nsfw: Whether any generated image was flagged as NSFW.

    Returns:
        None.
    """
    url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}/generatedImages"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": "ACTIVE", "has_nsfw": has_nsfw}
    try:
        with RetryRequest() as session:
            # A timeout is required: without one, a stalled server blocks the
            # worker forever and the Timeout handler below can never fire.
            response = session.patch(url, headers=headers, json=data, timeout=10)
            response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Request timed out while saving image")
    except requests.exceptions.RequestException as e:
        print("Failed to mark source as active: ", e)
def getStyles() -> Optional[Dict]:
    """Download the style catalogue from the autodraft API.

    Returns:
        The decoded JSON body of the response, or ``None`` if the request
        timed out.

    Raises:
        requests.exceptions.RequestException: any other request failure
            (including an undecodable JSON body) is printed, then re-raised.
    """
    styles_url = api_endpoint() + "/autodraft-crecoai/style"
    try:
        with RetryRequest() as session:
            # NOTE(review): a literal API key is committed in source here.
            # Consider moving it into internals.util.config or the environment.
            # Kept verbatim to preserve behavior. Keys from api_headers()
            # override it if they collide.
            request_headers = {"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()}
            reply = session.get(styles_url, headers=request_headers, timeout=10)
            return reply.json()
    except requests.exceptions.Timeout:
        print("Request timed out while fetching styles")
        return None
    except requests.exceptions.RequestException as err:
        print(f"Error while fetching styles: {err}")
        raise
def getCharacters(model_id: str) -> Optional[List]:
    """Fetch the characters attached to model ``model_id``.

    Returns:
        The ``["data"]["characters"]`` entry of the JSON response, or ``None``
        when anything fails: timeout, network error, bad JSON, missing keys,
        and so on. A message is printed in every failure case.
    """
    model_url = api_endpoint() + f"/autodraft-crecoai/model/{model_id}"
    try:
        with RetryRequest() as session:
            payload = session.get(model_url, timeout=10, headers=api_headers()).json()
            return payload["data"]["characters"]
    except requests.exceptions.Timeout:
        print("Request timed out while fetching characters")
    except Exception as exc:
        # Deliberately broad: a malformed payload (KeyError/TypeError) is
        # treated the same as a transport failure.
        print(f"Error while fetching characters: {exc}")
    return None
def update_db_source_failed(sourceId, userId):
    """Shortcut for ``updateSource(sourceId, userId, "FAILED")``."""
    updateSource(sourceId, userId, state="FAILED")
def update_db(func):
    """Decorate ``func`` so its run is mirrored into the source's DB state.

    The decorated function must receive a ``Task`` among its positional or
    keyword arguments. The first one found (positional first) is used.

    Around the call, the wrapper:

    1. marks the task's source ``INPROGRESS``;
    2. calls ``func``;
    3. reads ``"has_nsfw"`` (default ``False``) from the result via ``.get``;
    4. marks the source ``COMPLETED`` and saves the generated images with that
       flag;
    5. returns ``func``'s result.

    If any of these steps raises, the wrapper:

    - prints the error and traceback;
    - sends a Slack alert;
    - marks the source ``FAILED``, even if the alert itself fails;
    - returns ``None``. The original exception is not re-raised.

    Raises:
        TypeError: if no ``Task`` is found in the arguments (a subclass of
            ``Exception``, which the original code raised).
    """

    @functools.wraps(func)
    def caller(*args, **kwargs):
        # The old message claimed the Task had to be the first argument, but
        # any argument position has always been accepted.
        task = next(
            (arg for arg in (*args, *kwargs.values()) if isinstance(arg, Task)),
            None,
        )
        if task is None:
            raise TypeError("A Task object must be passed as an argument")
        try:
            updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
            rargs = func(*args, **kwargs)
            # NOTE(review): a non-dict result (e.g. None) raises AttributeError
            # here and the task is reported FAILED. Confirm every decorated
            # function returns a dict.
            has_nsfw = rargs.get("has_nsfw", False)
            updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
            saveGeneratedImages(task.get_sourceId(), task.get_userId(), has_nsfw)
            return rargs
        except Exception as e:
            print("Error processing image: {}".format(str(e)))
            traceback.print_exc()
            try:
                Slack().error_alert(task, e)
            finally:
                # Record the failure even if alerting raises. Previously, a
                # Slack error left the source stuck in INPROGRESS.
                updateSource(task.get_sourceId(), task.get_userId(), "FAILED")

    return caller
|