Gabriel committed on
Commit
2d4e52a
·
1 Parent(s): c155b0b

Fixing api not working #4

Browse files
.gitignore CHANGED
@@ -7,7 +7,6 @@ __pycache__/
7
  *.py[cod]
8
 
9
  vis_data/
10
- notebooks/
11
  output/
12
  my_xml_filename.xml
13
  models/RmtDet_regions/epoch_12.pth
@@ -32,7 +31,6 @@ data/
32
 
33
  #mlflow
34
  mlruns/
35
- test.ipynb
36
 
37
  #models
38
  models--Riksarkivet--HTR_pipeline_models/
 
7
  *.py[cod]
8
 
9
  vis_data/
 
10
  output/
11
  my_xml_filename.xml
12
  models/RmtDet_regions/epoch_12.pth
 
31
 
32
  #mlflow
33
  mlruns/
 
34
 
35
  #models
36
  models--Riksarkivet--HTR_pipeline_models/
notebooks/demo_api.ipynb ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 11,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Loaded as API: http://127.0.0.1:7860/ ✔\n",
13
+ "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n",
14
+ "<PcGts xmlns=\"http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15\" xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\" xsi:schemaLocation=\"http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15 http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15/pagecontent.xsd\">\n",
15
+ " <Metadata>\n",
16
+ " <Creator>Swedish National Archives</Creator>\n",
17
+ " <Created>2023-11-23, 09:41:42</Created>\n",
18
+ " </Metadata>\n",
19
+ " <Page imageFilename=\"page_xml.xml\" imageWidth=\"1629\" imageHeight=\"626\">\n",
20
+ " <TextRegion id=\"region_0\" custom=\"readingOrder {index:0;}\">\n",
21
+ " <Coords points=\"0,313 14,391 0,563 28,608 70,588 90,611 153,570 279,556 397,599 970,625 1026,585 1132,583 1191,611 1353,573 1599,575 1628,541 1595,482 1505,459 1547,451 1580,415 1596,316 1579,288 1536,295 1491,257 1445,169 1352,132 1171,120 1125,98 926,98 847,71 609,57 173,67 122,86 92,224\"/>\n",
22
+ " <TextLine id=\"line_region_0_0\" custom=\"readingOrder {index:0;}\">\n",
23
+ " <Coords points=\"124,121 134,200 479,207 617,224 681,248 782,231 1070,252 1110,284 1351,244 1364,209 1356,166 1332,152 965,135 846,81 744,111 600,104 514,54 423,87 338,72 317,54 247,54 241,66 154,79\"/>\n",
24
+ " <TextEquiv>\n",
25
+ " <Unicode>Hushållspenningar</Unicode>\n",
26
+ " </TextEquiv>\n",
27
+ " <PredScore pred_score=\"0.9796\"/>\n",
28
+ " </TextLine>\n",
29
+ " <TextLine id=\"line_region_0_1\" custom=\"readingOrder {index:1;}\">\n",
30
+ " <Coords points=\"26,331 32,371 1035,394 1191,410 1273,453 1370,424 1547,419 1573,391 1571,330 1520,313 1169,291 985,315 846,304 745,268 580,297 274,283 142,250 72,263 32,293\"/>\n",
31
+ " <TextEquiv>\n",
32
+ " <Unicode>Priminieration å Snällporten</Unicode>\n",
33
+ " </TextEquiv>\n",
34
+ " <PredScore pred_score=\"0.9221\"/>\n",
35
+ " </TextLine>\n",
36
+ " <TextLine id=\"line_region_0_2\" custom=\"readingOrder {index:2;}\">\n",
37
+ " <Coords points=\"0,452 0,539 28,570 271,539 676,546 735,592 825,596 985,553 1581,572 1614,535 1581,493 1356,457 1189,469 948,437 800,463 648,436 558,454 441,425 228,439 94,390 30,410\"/>\n",
38
+ " <TextEquiv>\n",
39
+ " <Unicode>Gasverksräkning för för ä kronor</Unicode>\n",
40
+ " </TextEquiv>\n",
41
+ " <PredScore pred_score=\"0.9599\"/>\n",
42
+ " </TextLine>\n",
43
+ " </TextRegion>\n",
44
+ " </Page>\n",
45
+ "</PcGts>\n",
46
+ "\n"
47
+ ]
48
+ }
49
+ ],
50
+ "source": [
51
+ "from gradio_client import Client # pip install gradio_client\n",
52
+ "\n",
53
+ "# Change url to your client (localhost: http://127.0.0.1:7860/)\n",
54
+ "\n",
55
+ "client = Client(\"http://127.0.0.1:7860/\")\n",
56
+ "job = client.submit(\n",
57
+ " \"https://github.com/Swedish-National-Archives-AI-lab/htrflow_core/blob/main/data/raw/demo_image.jpg?raw=true\", \n",
58
+ " \"Riksarkivet/satrn_htr\",\n",
59
+ " api_name=\"/run_htr_pipeline\",\n",
60
+ ")\n",
61
+ "\n",
62
+ "print(job.result())"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": []
71
+ }
72
+ ],
73
+ "metadata": {
74
+ "kernelspec": {
75
+ "display_name": "venv",
76
+ "language": "python",
77
+ "name": "python3"
78
+ },
79
+ "language_info": {
80
+ "codemirror_mode": {
81
+ "name": "ipython",
82
+ "version": 3
83
+ },
84
+ "file_extension": ".py",
85
+ "mimetype": "text/x-python",
86
+ "name": "python",
87
+ "nbconvert_exporter": "python",
88
+ "pygments_lexer": "ipython3",
89
+ "version": "3.10.9"
90
+ },
91
+ "orig_nbformat": 4
92
+ },
93
+ "nbformat": 4,
94
+ "nbformat_minor": 2
95
+ }
src/htr_pipeline/gradio_backend.py CHANGED
@@ -88,8 +88,8 @@ class FastTrack:
88
  ): # >= 0 means on the polygon or inside
89
  return text
90
 
91
- def segment_to_xml_api(self, image):
92
- rendered_xml = self.pipeline.running_htr_pipeline(image)
93
  return rendered_xml
94
 
95
 
@@ -202,7 +202,8 @@ def compute_cer_a_and_b_with_gt(run_a, run_b, run_gt):
202
  return f"A & B -> GT: {round(cer_metric.compute(predictions=[text_run_a], references=[text_run_gt]), 4)}"
203
 
204
  else:
205
- return f"A -> GT: {round(cer_metric.compute(predictions=[text_run_a], references=[text_run_gt]), 4)}, B -> GT {round(cer_metric.compute(predictions=[text_run_b], references=[text_run_gt]), 4)}"
 
206
 
207
 
208
  def temporary_xml_parser(page_xml):
 
88
  ): # >= 0 means on the polygon or inside
89
  return text
90
 
91
+ def segment_to_xml_api(self, image, model="Riksarkivet/satrn_htr"):
92
+ rendered_xml = self.pipeline.running_htr_pipeline(image, model)
93
  return rendered_xml
94
 
95
 
 
202
  return f"A & B -> GT: {round(cer_metric.compute(predictions=[text_run_a], references=[text_run_gt]), 4)}"
203
 
204
  else:
205
+ return f"A -> GT: {round(cer_metric.compute(predictions=[text_run_a], references=[text_run_gt]), 4)} \
206
+ , B -> GT {round(cer_metric.compute(predictions=[text_run_b], references=[text_run_gt]), 4)}"
207
 
208
 
209
  def temporary_xml_parser(page_xml):
src/htr_pipeline/pipeline.py CHANGED
@@ -62,6 +62,7 @@ class PipelineInterface(Protocol):
62
  def running_htr_pipeline(
63
  self,
64
  input_image: np.ndarray,
 
65
  pred_score_threshold_regions: float = 0.4,
66
  pred_score_threshold_lines: float = 0.4,
67
  containments_threshold: float = 0.5,
 
62
  def running_htr_pipeline(
63
  self,
64
  input_image: np.ndarray,
65
+ htr_tool_transcriber_model_dropdown: str,
66
  pred_score_threshold_regions: float = 0.4,
67
  pred_score_threshold_lines: float = 0.4,
68
  containments_threshold: float = 0.5,
tabs/htr_tool.py CHANGED
@@ -1,8 +1,5 @@
1
  import os
2
- SECRET_KEY = os.environ.get("HUB_TOKEN", False)
3
- if SECRET_KEY:
4
- from helper.utils import TrafficDataHandler
5
-
6
  import gradio as gr
7
 
8
  from helper.examples.examples import DemoImages
@@ -17,6 +14,11 @@ from src.htr_pipeline.gradio_backend import (
17
  upload_file,
18
  )
19
 
 
 
 
 
 
20
  model_loader = SingletonModelLoader()
21
  fast_track = FastTrack(model_loader)
22
  images_for_demo = DemoImages()
@@ -240,7 +242,7 @@ with gr.Blocks() as htr_tool_tab:
240
 
241
  htr_pipeline_button_api.click(
242
  fast_track.segment_to_xml_api,
243
- inputs=[fast_track_input_region_image],
244
  outputs=[xml_rendered_placeholder_for_api],
245
  queue=False,
246
  api_name="run_htr_pipeline",
 
1
  import os
2
+
 
 
 
3
  import gradio as gr
4
 
5
  from helper.examples.examples import DemoImages
 
14
  upload_file,
15
  )
16
 
17
+ SECRET_KEY = os.environ.get("HUB_TOKEN", False)
18
+ if SECRET_KEY:
19
+ from helper.utils import TrafficDataHandler
20
+
21
+
22
  model_loader = SingletonModelLoader()
23
  fast_track = FastTrack(model_loader)
24
  images_for_demo = DemoImages()
 
242
 
243
  htr_pipeline_button_api.click(
244
  fast_track.segment_to_xml_api,
245
+ inputs=[fast_track_input_region_image, htr_tool_transcriber_model_dropdown],
246
  outputs=[xml_rendered_placeholder_for_api],
247
  queue=False,
248
  api_name="run_htr_pipeline",
tabs/stepwise_htr_tool.py CHANGED
@@ -1,8 +1,4 @@
1
  import os
2
- SECRET_KEY = os.environ.get("HUB_TOKEN", False)
3
- if SECRET_KEY:
4
- from helper.utils import TrafficDataHandler
5
-
6
  import shutil
7
  from difflib import Differ
8
 
@@ -12,6 +8,11 @@ import gradio as gr
12
  from helper.examples.examples import DemoImages
13
  from src.htr_pipeline.gradio_backend import CustomTrack, SingletonModelLoader
14
 
 
 
 
 
 
15
  model_loader = SingletonModelLoader()
16
 
17
  custom_track = CustomTrack(model_loader)
 
1
  import os
 
 
 
 
2
  import shutil
3
  from difflib import Differ
4
 
 
8
  from helper.examples.examples import DemoImages
9
  from src.htr_pipeline.gradio_backend import CustomTrack, SingletonModelLoader
10
 
11
+ SECRET_KEY = os.environ.get("HUB_TOKEN", False)
12
+ if SECRET_KEY:
13
+ from helper.utils import TrafficDataHandler
14
+
15
+
16
  model_loader = SingletonModelLoader()
17
 
18
  custom_track = CustomTrack(model_loader)