Update core/refiner/simple_refiner.py
Browse files
core/refiner/simple_refiner.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
from models import BaseModel
|
2 |
from .base_refiner import BaseRefiner
|
3 |
from utils.image_encoder import encode_image
|
|
|
|
|
4 |
class SimpleRefiner(BaseRefiner):
|
5 |
def __init__(self,
|
6 |
sys_prompt: str,
|
@@ -8,6 +10,30 @@ class SimpleRefiner(BaseRefiner):
|
|
8 |
) -> None:
|
9 |
BaseRefiner.__init__(self, sys_prompt=sys_prompt, model=model)
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
def refine(self, message: str, memory, image_paths=None) -> str:
|
12 |
if memory is None:
|
13 |
memory = []
|
|
|
1 |
from models import BaseModel
|
2 |
from .base_refiner import BaseRefiner
|
3 |
from utils.image_encoder import encode_image
|
4 |
+
import asyncio
|
5 |
+
|
6 |
class SimpleRefiner(BaseRefiner):
|
7 |
def __init__(self,
|
8 |
sys_prompt: str,
|
|
|
10 |
) -> None:
|
11 |
BaseRefiner.__init__(self, sys_prompt=sys_prompt, model=model)
|
12 |
|
13 |
+
async def refine_async(self, message: str, memory, image_paths=None) -> str:
|
14 |
+
if memory is None:
|
15 |
+
memory = []
|
16 |
+
else:
|
17 |
+
memory = memory.messages[1:]
|
18 |
+
|
19 |
+
user_context = [{"role": "user", "content": [
|
20 |
+
{"type": "text", "text": f"{message}"},]}]
|
21 |
+
if image_paths:
|
22 |
+
if not isinstance(image_paths, list):
|
23 |
+
image_paths = [image_paths]
|
24 |
+
for image_path in image_paths:
|
25 |
+
user_context.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path.name)}"}})
|
26 |
+
context = [{"role": "system", "content": self.sys_prompt}] + memory + [{"role": "user", "content": [
|
27 |
+
{"type": "text", "text": f"{message}"},
|
28 |
+
]}]
|
29 |
+
else:
|
30 |
+
context = [{"role": "system", "content": self.sys_prompt}] + memory + user_context
|
31 |
+
|
32 |
+
respond_task = asyncio.create_task(self.model.respond_async(context))
|
33 |
+
await respond_task
|
34 |
+
response = respond_task.result()
|
35 |
+
return response
|
36 |
+
|
37 |
def refine(self, message: str, memory, image_paths=None) -> str:
|
38 |
if memory is None:
|
39 |
memory = []
|