Spaces:
Runtime error
Runtime error
Fix utils.py
Browse files
app.py
CHANGED
@@ -1,11 +1,9 @@
|
|
1 |
import gradio as gr
|
2 |
-
import torch
|
3 |
import numpy as np
|
4 |
from sklearn.metrics.pairwise import cosine_similarity
|
5 |
import pandas as pd
|
6 |
-
from PIL import Image
|
7 |
import matplotlib.pyplot as plt
|
8 |
-
import matplotlib
|
9 |
|
10 |
from pipelines.detection.yolo_v8 import Yolov8Pipeline
|
11 |
from pipelines.detection.yolo_stamp import YoloStampPipeline
|
@@ -41,17 +39,21 @@ def doc_predict(image, det_choice, seg_choice, emb_choice):
|
|
41 |
cropped_stamp = image.crop(box.tolist())
|
42 |
segmented_stamps.append(dlv3(cropped_stamp) if seg_choice else cropped_stamp)
|
43 |
|
44 |
-
|
|
|
45 |
|
46 |
-
|
47 |
-
|
48 |
|
49 |
-
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
55 |
|
56 |
embeddings = []
|
57 |
if emb_choice == 'vits8':
|
|
|
1 |
import gradio as gr
|
|
|
2 |
import numpy as np
|
3 |
from sklearn.metrics.pairwise import cosine_similarity
|
4 |
import pandas as pd
|
5 |
+
from PIL import Image
|
6 |
import matplotlib.pyplot as plt
|
|
|
7 |
|
8 |
from pipelines.detection.yolo_v8 import Yolov8Pipeline
|
9 |
from pipelines.detection.yolo_stamp import YoloStampPipeline
|
|
|
39 |
cropped_stamp = image.crop(box.tolist())
|
40 |
segmented_stamps.append(dlv3(cropped_stamp) if seg_choice else cropped_stamp)
|
41 |
|
42 |
+
if len(segmented_stamps) != 0:
|
43 |
+
widths, heights = zip(*(i.size for i in segmented_stamps))
|
44 |
|
45 |
+
total_width = sum(widths)
|
46 |
+
max_height = max(heights)
|
47 |
|
48 |
+
concatenated_stamps = Image.new('RGB', (total_width, max_height))
|
49 |
|
50 |
+
x_offset = 0
|
51 |
+
for im in segmented_stamps:
|
52 |
+
concatenated_stamps.paste(im, (x_offset,0))
|
53 |
+
x_offset += im.size[0]
|
54 |
+
|
55 |
+
else:
|
56 |
+
concatenated_stamps = Image.new('RGB', (0, 0))
|
57 |
|
58 |
embeddings = []
|
59 |
if emb_choice == 'vits8':
|
utils.py
CHANGED
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
1 |
def heatmap(data, row_labels, col_labels, ax=None,
|
2 |
cbar_kw=None, cbarlabel="", **kwargs):
|
3 |
"""
|
|
|
1 |
+
import matplotlib
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
from PIL import Image, ImageDraw
|
4 |
+
|
5 |
def heatmap(data, row_labels, col_labels, ax=None,
|
6 |
cbar_kw=None, cbarlabel="", **kwargs):
|
7 |
"""
|