Spaces:
Running
Running
import math | |
import re | |
import logging | |
from dataclasses import dataclass | |
from natsort import natsort_key | |
import cv2 | |
import numpy as np | |
import torch | |
from torchvision import transforms | |
from PyQt5 import QtCore | |
from PyQt5.QtCore import Qt | |
# Module-level logger shared by the helpers below.
log = logging.getLogger(name=__name__)
def imread(
    filename,
    flags=cv2.IMREAD_COLOR,
    dtype=np.uint8
):
    '''
    Read an image file into a numpy array.

    Works around cv2.imread() failing on file names that contain non-English
    characters: the raw file bytes are loaded with np.fromfile and decoded in
    memory with cv2.imdecode.

    Args:
        filename: path of the image file.
        flags: decode flags forwarded to cv2.imdecode.
        dtype: element type used when reading the raw file bytes.

    Returns:
        The decoded image, or None if the file could not be read or decoded.
    '''
    try:
        buffer = np.fromfile(filename, dtype)
        img = cv2.imdecode(buffer, flags)
    except Exception:
        # Best-effort reader: report through the module logger (instead of
        # print) and keep returning None as before.
        log.exception("Failed to read image %s", filename)
        return None
    if img is None:
        # cv2.imdecode reports undecodable data by returning None, not by raising.
        log.warning("Failed to decode image %s", filename)
    return img
def sort_box(
    boxes
):
    '''
    Order boxes from the top of the image to the bottom.

    Args:
        boxes: 2-D array with one box per row; column 1 (y1 in an
            [x1, y1, x2, y2] layout) is used as the sort key.

    Returns:
        A new array holding the same rows, sorted by ascending column 1.
    '''
    top_edges = boxes[:, 1]
    order = np.argsort(top_edges)
    return boxes[order]
def crop(
    img,
    boxes
):
    '''
    Crop img with the given boxes and return the list of sub-images.

    Args:
        img: image array indexed as img[y, x, ...]; both 2-D (grayscale) and
            3-D (height, width, channels) arrays are accepted.
        boxes: a single box or a 2-D array/sequence of boxes, each laid out as
            [x1, y1, x2, y2, ...]; only the first four values are used.
            Float coordinates are truncated to ints.

    Returns:
        A list with one sub-image of img per box, in box order.
    '''
    if np.ndim(boxes) == 1:
        # A lone box: wrap it so the loop below handles both cases uniformly.
        boxes = [boxes]
    results = []
    for box in boxes:
        # Slicing needs integer indices; detector outputs are often floats.
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        # `...` keeps every trailing axis, so grayscale images work too.
        results.append(img[y1:y2, x1:x2, ...])
    return results
class NormalizePAD:
    '''
    Pre-process an image before it is fed to the OCR recognizer.

    The image is converted to a tensor, normalised with (x - 0.5) / 0.5, and
    right-padded to the fixed size max_size. The padded area is filled by
    replicating the image's last column (border padding).
    '''

    def __init__(self, max_size, PAD_type='right'):
        self.toTensor = transforms.ToTensor()
        # Unpacked as (channels, height, width) when padding in __call__.
        self.max_size = max_size
        # Kept for API compatibility; __call__ always pads on the right.
        self.PAD_type = PAD_type

    def __call__(self, img):
        tensor = self.toTensor(img)
        # In-place normalisation: [0, 1] -> [-1, 1].
        tensor.sub_(0.5).div_(0.5)
        channels, height, width = tensor.size()
        target_width = self.max_size[2]
        padded = torch.zeros(*self.max_size, dtype=torch.float32, device="cpu")
        padded[:, :, :width] = tensor
        if width != target_width:
            # Repeat the last column across the remaining width.
            last_column = tensor[:, :, width - 1].unsqueeze(2)
            padded[:, :, width:] = last_column.expand(channels, height, target_width - width)
        return padded
def custom_mean(x):
    '''
    Combine per-character OCR confidences into a single score.

    Returns prod(x) ** (2 / sqrt(len(x))).
    '''
    product = x.prod()
    exponent = 2.0 / np.sqrt(len(x))
    return product ** exponent
class AttrDict(dict):
    '''
    Holds model parameters: a dict whose keys can also be read and written
    as attributes (d.key is d["key"]).
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point the instance __dict__ at the dict itself so attribute access
        # and item access share the same storage.
        self.__dict__ = self
# Runs of decimal digits.
_DIGITS_PATTERN = re.compile(r"\d+")
# A "*<digits>" multiplier at the very end, or any parenthesised annotation.
_MULTIPLIER_PATTERN = re.compile(r"\*\d+$|\(.*\)")

# Section header prefixes of a "complex" (胆拖) ticket, in matching order:
# red/front required, red/front optional, back required (cjdlt only),
# blue/back optional.
_COMPLEX_SECTIONS = (
    ("红胆", "前区胆"),
    ("红拖", "前区拖"),
    ("后区胆",),
    ("蓝单", "后区拖", "蓝复"),
)
# Section header prefixes of a labelled "compound" ticket: red/front, blue/back.
_COMPOUND_SECTIONS = (
    ("前区", "红区", "红单", "红复"),
    ("后区", "蓝区", "蓝单", "蓝复"),
)


def _match_fill(string):
    '''Return every run of digits in string, zero-padded to two characters.'''
    return [num.zfill(2) for num in _DIGITS_PATTERN.findall(string)]


def _split_single(matched, code):
    '''Split one single-play ticket into (red, blue): 6 + 1 for ssq, 5 + 2 otherwise.'''
    if code == "ssq":
        return matched[: 6], [matched[6]]
    return matched[: 5], matched[5 : 7]


def _collect_sections(lines, section_prefixes):
    '''
    Group the numbers of a labelled multi-line ticket into sections.

    A line starting with one of section_prefixes[i] starts section i (replacing
    whatever was collected for it before); a line starting with "倍数" (the
    multiplier) is skipped; any other line continues the most recently started
    section. Returns one list of zero-padded numbers per section.
    '''
    sections = [[] for _ in section_prefixes]
    current = None
    for raw in lines:
        line = raw.strip()
        index = next(
            (i for i, prefixes in enumerate(section_prefixes) if line.startswith(prefixes)),
            None,
        )
        if index is not None:
            sections[index] = _match_fill(line)
            current = sections[index]
        elif line.startswith("倍数"):
            continue
        elif current is not None:
            # Callers only get here after the first line matched a header,
            # so `current` is always set; the guard is purely defensive.
            current.extend(_match_fill(line))
    return sections


def _parse_unlabelled(numbers, code):
    '''
    Parse a ticket printed without section labels.

    All lines are joined and multiplier annotations removed. Exactly 7 numbers
    form a single ticket; more form a compound ticket whose red and blue halves
    are separated by a "-", "+" or "*" that occurs exactly once.
    Returns (game_type, (red, blue)).
    '''
    line = _MULTIPLIER_PATTERN.sub("", "".join(numbers))
    matched = _match_fill(line)
    if len(matched) < 7:
        raise MissingInfoException("Matched numbers not enough.")
    if len(matched) == 7:
        return "single", _split_single(matched, code)
    # The first candidate connector that occurs exactly once splits the halves.
    section_con = next((con for con in ("-", "+", "*") if line.count(con) == 1), None)
    if section_con is None:
        raise MissingInfoException("Sections connector not found.")
    red_half, blue_half = line.split(section_con)
    return "compound", (_match_fill(red_half), _match_fill(blue_half))


def number_process(numbers, code):
    '''
    Pre-process the lottery numbers recognised on a ticket.

    Args:
        numbers: OCR text lines holding the ticket's numbers.
        code: lottery code, "ssq" or "cjdlt".

    Returns:
        dict with keys "code", "game_type" ("single", "complex" or "compound")
        and "numbers":
          - "single" / "compound": list of (red, blue) tuples of lists;
          - "complex": one (red_required, red_optional, blue_required,
            blue_optional) tuple of lists.
        Every number is a string zero-padded to two characters.

    Raises:
        AssertionError: code is not "ssq" or "cjdlt".
        MissingInfoException: no lines were given, too few numbers were
            recognised, or an unlabelled compound ticket has no usable
            section connector.
    '''
    assert code in ["ssq", "cjdlt"], f"Code {code} is illegal."
    if not numbers:
        raise MissingInfoException("Matched numbers not enough.")
    first_line = numbers[0].strip()
    results = []
    # The dispatch order matters: e.g. "前区胆" must be seen as complex before
    # the plain "前区" compound prefix is tried.
    if first_line.startswith(("A", "1)")):
        game_type = "single"
        for line in numbers:
            # Drop the one-character ticket label; strip first so leading
            # whitespace is not dropped instead of the label. The ")" left by
            # labels such as "1)" holds no digits and is ignored.
            matched = _match_fill(line.strip()[1:])
            if len(matched) < 7:
                raise MissingInfoException("Matched numbers not enough.")
            results.append(_split_single(matched, code))
    elif first_line.startswith(_COMPLEX_SECTIONS[0]):
        game_type = "complex"
        sections = _collect_sections(numbers, _COMPLEX_SECTIONS)
        if sum(len(section) for section in sections) <= 7:
            raise MissingInfoException("Matched numbers not enough.")
        results.append(tuple(sections))
    elif first_line.startswith(_COMPOUND_SECTIONS[0]):
        game_type = "compound"
        red, blue = _collect_sections(numbers, _COMPOUND_SECTIONS)
        if len(red) + len(blue) <= 7:
            raise MissingInfoException("Matched numbers not enough.")
        results.append((red, blue))
    else:
        game_type, ticket = _parse_unlabelled(numbers, code)
        results.append(ticket)
    return {
        "code" : code,
        "game_type" : game_type,
        "numbers" : results
    }
class MissingInfoException(Exception):
    '''
    Raised when the recognised text lacks information needed to build a
    result, e.g. too few numbers or no section connector.

    Positional arguments are stored in ``args`` exactly like Exception.
    '''
def issue_process(issue_string):
    '''
    Extract the draw/sale issue number from a recognised string.

    Args:
        issue_string: OCR text containing the issue number.

    Returns:
        The single run of digits in issue_string, or None when there is no
        digit run or more than one (ambiguous).
    '''
    issue = re.findall(r"\d+", issue_string)
    # Previously an empty match list fell through to issue[0] and raised
    # IndexError; "none" and "ambiguous" are now both reported as None.
    if len(issue) != 1:
        return None
    return issue[0]
def winning_process(winning_number, code):
    '''
    Split the winning-number string into red and blue numbers.

    Args:
        winning_number: text containing the drawn numbers in order.
        code: "ssq" (6 red + 1 blue) or any other code (5 red + 2 blue).

    Returns:
        (red, blue) lists of strings, zero-padded to two characters.

    Raises:
        MissingInfoException: fewer than 7 numbers were found.
    '''
    pattern = re.compile(r"\d+")
    # Pad like number_process does, so user and winning numbers compare as
    # equal strings (e.g. "7" -> "07").
    matched = [num.zfill(2) for num in pattern.findall(winning_number)]
    if len(matched) < 7:
        raise MissingInfoException("Matched numbers not enough.")
    if code == "ssq":
        red = matched[: 6]
        blue = [matched[6]]
    else:
        red = matched[: 5]
        blue = matched[5 : 7]
    return red, blue
def _sorted_hits(picked, winning):
    '''Return the numbers present in both picked and winning, naturally sorted.'''
    return sorted(set(picked) & set(winning), key=natsort_key)


def hit_check(numbers, winning_numbers):
    '''
    Match the user's numbers against the winning numbers.

    Args:
        numbers: dict with "game_type" and "numbers" keys, laid out like the
            result of number_process.
        winning_numbers: (red, blue) lists of winning numbers.

    Returns:
        One tuple of hit lists per entry of numbers["numbers"]:
        (red_hit, blue_hit) for "single"/"compound", or
        (red_required_hit, red_optional_hit, blue_required_hit,
        blue_optional_hit) for "complex".
    '''
    # Lazy %-style arguments: the old calls passed an argument without a
    # placeholder, which made every log record fail to format.
    log.info("Winning numbers are: %s", winning_numbers)
    red_win, blue_win = winning_numbers
    hits = []
    if numbers["game_type"] in ("single", "compound"):
        for number in numbers["numbers"]:
            log.info("User numbers are: %s", number)
            red, blue = number
            hit = (_sorted_hits(red, red_win), _sorted_hits(blue, blue_win))
            log.info("Hit numbers are: %s", hit)
            hits.append(hit)
    else:
        for number in numbers["numbers"]:
            log.info("User numbers are: %s", number)
            red_required, red_optional, blue_required, blue_optional = number
            hit = (
                _sorted_hits(red_required, red_win),
                _sorted_hits(red_optional, red_win),
                _sorted_hits(blue_required, blue_win),
                _sorted_hits(blue_optional, blue_win),
            )
            log.info("Hit numbers are: %s", hit)
            hits.append(hit)
    return hits
class Result:
    '''
    Editable bridge between the recognised ticket data (back end) and the
    table shown to the user (front end), so recognition results can be
    corrected by hand.

    Different lotteries and play types are displayed in different layouts,
    which makes handling manual edits rather convoluted; no cleaner design has
    been found yet.

    Table layout:
        column 0, rows 0-3: lottery type, issue, winning numbers, play type;
        column 0, rows 4+:  one row per ticket ("single"/"compound"), or one
                            row per required/optional section ("complex");
        column 1, rows 4+:  hit counts, when hits are set.
    '''

    def __init__(self, code: str, issue: str, game_type: str, numbers: list, winning: tuple = None, hits: list = None):
        self.code = code            # "ssq" or "cjdlt"
        self.issue = issue
        self.game_type = game_type  # "single", "compound" or "complex"
        self.numbers = numbers      # same layout as number_process(...)["numbers"]
        self.winning = winning      # (red, blue) lists, or None until fetched
        self.hits = hits            # hit lists laid out like numbers, or None
        self.fixed_headers = ["彩票类型", "开奖期", "开奖号码", "玩法"]
        self.fixed_row = len(self.fixed_headers)  # first row holding numbers

    def fromTuple(self, t):
        '''
        Build a new Result from a tuple; self is not used.

        t is either (code, issue, numbers_dict) or
        (code, issue, winning, numbers_dict, hits), where numbers_dict has
        "game_type" and "numbers" keys.
        '''
        if len(t) == 3:
            code, issue, numbers_ = t
            return Result(code, issue, numbers_["game_type"], numbers_["numbers"])
        code, issue, winning, numbers_, hits = t
        return Result(code, issue, numbers_["game_type"], numbers_["numbers"], winning, hits)

    def toTuple(self):
        '''Return (code, issue, numbers_dict), the 3-element form fromTuple accepts.'''
        return self.code, self.issue, {"code": self.code, "game_type": self.game_type, "numbers": self.numbers}

    def codeConvert(self, code):
        '''Map a lottery code to its display name.'''
        return "双色球" if code == "ssq" else "超级大乐透"

    def codeRevert(self, s):
        '''Map a lottery display name back to its code.'''
        return "ssq" if s == "双色球" else "cjdlt"

    def winningConvert(self, winning):
        '''Format (red, blue) numbers as "r1 r2 ... + b1 ...".'''
        return " ".join(winning[0] + ["+"] + winning[1])

    def gameConvert(self, game):
        '''Map a play type to its display name.'''
        convert = {
            "single": "单式",
            "compound": "复式",
            "complex": "胆拖"
        }
        return convert[game]

    def gameRevert(self, s):
        '''Map a play-type display name back to its internal name.'''
        revert = {
            "单式": "single",
            "复式": "compound",
            "胆拖": "complex"
        }
        return revert[s]

    def _complexSections(self):
        '''Indices of the displayed sections of a "complex" ticket, in row order.'''
        # ssq has no back-zone required section (index 2), so only 0, 1 and 3 are shown.
        return [0, 1, 2, 3] if self.code == "cjdlt" else [0, 1, 3]

    def numbersConvert(self):
        '''Text of every number row, in table row order.'''
        if self.game_type in ["single", "compound"]:
            return [" ".join(num[0] + ["+"] + num[1]) for num in self.numbers]
        return [" ".join(self.numbers[0][i]) for i in self._complexSections()]

    def hitsConvert(self):
        '''Hit-count text of every number row, in table row order.'''
        if self.game_type in ["single", "compound"]:
            return ["中" + str(len(hit[0])) + " + " + str(len(hit[1])) for hit in self.hits]
        return ["中" + str(len(self.hits[0][i])) for i in self._complexSections()]

    def numbersWithHitsAndHeader(self):
        '''Lines "header:numbers", with " (hits)" appended when hits are set.'''
        headers = self.toHeaderList()[self.fixed_row:]
        if not self.hits:
            return [header + ":" + num for header, num in zip(headers, self.numbersConvert())]
        return [header + ":" + num + " (" + hit + ")" for header, num, hit in zip(headers, self.numbersConvert(), self.hitsConvert())]

    def toHeaderList(self):
        '''Vertical header labels; their count is the table's row count.'''
        if self.game_type in ["single", "compound"]:
            return self.fixed_headers + list("①②③④⑤⑥⑦⑧⑨⑩")[: len(self.numbers)]
        if self.code == "ssq":
            return self.fixed_headers + ["红胆", "红拖", "蓝单" if len(self.numbers[0][3]) == 1 else "蓝复"]
        return self.fixed_headers + ["前区胆", "前区拖", "后区胆", "后区拖"]

    def getData(self, index):
        '''Display text for the cell at index (an object with row()/column()); None for empty cells.'''
        row, col = index.row(), index.column()
        if col == 0:
            if row == 0:
                return self.codeConvert(self.code)
            if row == 1:
                return self.issue
            if row == 2:
                return self.winningConvert(self.winning) if self.winning else "点击查询按钮自动获取"
            if row == 3:
                return self.gameConvert(self.game_type)
            return self.numbersConvert()[row - self.fixed_row]
        if col == 1 and self.hits and row >= self.fixed_row:
            return self.hitsConvert()[row - self.fixed_row]
        return None

    @staticmethod
    def _parseNumbers(text):
        '''Split text on whitespace into zero-padded numbers; None if any token is not decimal.'''
        tokens = text.split()
        # isdecimal (not isnumeric) rejects characters such as "五" or "①",
        # matching the \d used when numbers are recognised.
        if not all(token.isdecimal() for token in tokens):
            return None
        # Pad like recognised numbers so hit matching compares equal strings.
        return [token.zfill(2) for token in tokens]

    def setData(self, index, text):
        '''
        Apply a user edit of column 0 at index.

        Returns True when the edit was accepted and stored, False otherwise.
        Hits are not recomputed here.
        '''
        row, col = index.row(), index.column()
        if col != 0:
            return False  # only the first column is editable
        if row >= len(self.toHeaderList()):
            return False
        text = text.strip()
        if row == 0:
            if text in ["超级大乐透", "双色球"]:
                self.code = self.codeRevert(text)
                return True
            return False
        if row == 1:
            if text.isdecimal():
                self.issue = text
                return True
            return False
        if row in (2, 3):
            return False  # winning numbers and play type are read-only
        if self.game_type in ["single", "compound"]:
            splits = text.split("+")
            if len(splits) != 2:
                return False
            red = self._parseNumbers(splits[0])
            blue = self._parseNumbers(splits[1])
            if red is None or blue is None:
                return False
            self.numbers[row - self.fixed_row] = (red, blue)
            return True
        parsed = self._parseNumbers(text)
        if parsed is None:
            return False
        # Map the table row to the section actually displayed there; for ssq
        # the third row is the blue section (index 3), not index 2.
        target = self.numbers[0][self._complexSections()[row - self.fixed_row]]
        # The sections live in a tuple (immutable) of lists, so update the list in place.
        target.clear()
        target.extend(parsed)
        return True

    def __str__(self):
        return f"彩票类型:{self.codeConvert(self.code)}\n" + f"开奖期:{self.issue}\n" \
            + f"开奖号码:{self.winningConvert(self.winning) if self.winning else '未知'}\n" \
            + f"玩法:{self.gameConvert(self.game_type)}\n" + "\n".join(self.numbersWithHitsAndHeader())
class TableModel(QtCore.QAbstractTableModel):
    '''
    Qt table model exposing one Result object.

    Column 0 shows (and edits) the values, column 1 the hit counts, and the
    vertical header the row labels from Result.toHeaderList().
    '''

    def __init__(self, results, parent=None):
        super().__init__(parent)
        self.results = results  # the Result being displayed and edited

    def data(self, index, role=Qt.ItemDataRole.DisplayRole):
        if role in (Qt.ItemDataRole.DisplayRole, Qt.ItemDataRole.EditRole):
            return self.results.getData(index)
        return None

    def rowCount(self, parent=QtCore.QModelIndex()):
        return len(self.results.toHeaderList())

    def columnCount(self, parent=QtCore.QModelIndex()):
        return 2

    def headerData(self, section, orientation, role=Qt.ItemDataRole.DisplayRole):
        if role == Qt.ItemDataRole.DisplayRole:
            if orientation == Qt.Orientation.Vertical:
                return self.results.toHeaderList()[section]
            return ""
        return None

    def setData(self, index, value, role=Qt.ItemDataRole.EditRole):
        if not (index.isValid() and role == Qt.ItemDataRole.EditRole):
            return False
        if not self.results.setData(index, value):
            return False
        # Qt requires dataChanged after a successful edit so attached views refresh.
        self.dataChanged.emit(index, index, [role])
        # Row labels depend on the data (lottery type, number of blue numbers),
        # so refresh the vertical header too.
        # NOTE(review): switching the lottery type of a "complex" ticket also
        # changes rowCount(); that would need a model reset to be fully correct.
        self.headerDataChanged.emit(Qt.Orientation.Vertical, 0, self.rowCount() - 1)
        return True

    def flags(self, index):
        if index.isValid():
            return Qt.ItemFlag.ItemIsEnabled | Qt.ItemFlag.ItemIsSelectable | Qt.ItemFlag.ItemIsEditable
        return Qt.ItemFlag.NoItemFlags