Spaces:
No application file
No application file
File size: 13,588 Bytes
6755a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 |
import warnings
import logging
import os
import pickle
import copy
import time
import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F
from .TransNetmodels import TransNetV2
# Silence every Python warning for the whole process as soon as this module is imported.
# NOTE(review): this also hides deprecation/runtime warnings raised by torch and
# pytorch_lightning; consider scoping it (e.g. warnings.catch_warnings) instead.
warnings.filterwarnings("ignore")
import logging  # NOTE(review): duplicate of the `import logging` above; harmless.
# Module-level logger used by MInterface for init / optimizer-configuration messages.
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
## Utility functions
def complete_results_batch(
    mp4_ids,
    batch_mp4_scenes_index,
    fps_batch,
    single_frame_pred,
    class_threshold,
    cache_file="/data_share7/v_hyggewang/视频切换模型依赖数据/转场真实标签字典.pkl",
    merge_gap=0.035,
):
    """Turn the per-frame logits of one validation batch into transition timestamps.

    The batch concatenates the clips of several videos along dim 0; video ``i``
    owns the next ``batch_mp4_scenes_index[i]`` clips.

    Args:
        mp4_ids: video ids; each is converted with ``int()`` and used as a key
            into the ground-truth cache.
        batch_mp4_scenes_index: number of clips belonging to each video.
        fps_batch: frame rate of each video, used to turn frame indices into seconds.
        single_frame_pred: logits of shape ``[num_clips, 100, 2]``.
        class_threshold: a frame is predicted as a transition when its softmax
            probability for class 1 is >= this value.
        cache_file: pickle file holding a dict ``{video_id: ground-truth labels}``.
        merge_gap: a predicted time (seconds) closer than this to the previous
            time of the current run is merged into the same transition.

    Returns:
        One dict per video: ``{"真实标签": ground-truth labels from the cache,
        "预测标签": list of predicted transition times in seconds (ascending)}``.
    """
    # NOTE: pickle can execute arbitrary code while loading; only use trusted cache files.
    with open(cache_file, "rb") as f:
        cache = pickle.load(f)

    result = []
    start = 0
    for mp4_id, num_clips, fps in zip(mp4_ids, batch_mp4_scenes_index, fps_batch):
        num_clips = int(num_clips)
        # Only the central 70 frames (15..84) of every 100-frame clip are scored, so
        # consecutive clips tile the video as frames 15-85, 85-155, ... (see how the
        # validation set is generated in the dataset code).
        logits = single_frame_pred[start : start + num_clips, 15:-15, :].reshape(
            -1, 70, 2
        )
        probs = F.softmax(logits, dim=-1)  # per-frame class probabilities
        is_transition = (probs[:, :, -1] >= class_threshold).reshape(-1)
        # Flattened index k is global frame k + 15 (the first 15 frames are skipped).
        transition_times = ((torch.where(is_transition)[0] + 15) / fps).tolist()

        # Merge runs of nearly-adjacent transition frames into one transition.
        # NOTE(review): with the default 0.035 s gap, consecutive frames of a video
        # below ~28.6 fps (1/fps >= 0.035 s) are never merged — confirm intended.
        groups = []
        for t in transition_times:
            if groups and abs(groups[-1][-1] - t) < merge_gap:
                groups[-1].append(t)
            else:
                groups.append([t])
        # Average in float64: float16 keeps only ~3 significant digits (0.25 s steps
        # between 256 s and 512 s), coarser than the evaluation's matching tolerances.
        predicted = [float(np.mean(g)) for g in groups]

        result.append({"真实标签": cache[int(mp4_id)], "预测标签": predicted})
        start += num_clips
    return result
### Utility functions
def pr_call(label_list, thresholds=(0.1, 0.3, 0.5, 0.7)):
    """Compute precision and recall of predicted transitions for several time tolerances.

    A prediction matches a ground-truth transition when it lies within
    ``±threshold`` of it. Matching is greedy: each ground-truth label takes the
    earliest not-yet-used prediction inside its window, and every prediction can
    be used at most once.

    Args:
        label_list: iterable of dicts with keys ``"真实标签"`` (ground-truth times)
            and ``"预测标签"`` (predicted times, assumed sorted ascending — as
            produced by ``complete_results_batch``).
        thresholds: time tolerances to evaluate.

    Returns:
        ``{threshold: {"precision": float, "recall": float}}``.
    """
    correct_num_dict = {threshold: 0 for threshold in thresholds}
    pre_positive_num = 0  # total predicted positives over all samples
    GT_positive_num = 0  # total ground-truth positives over all samples
    for label_dic in label_list:
        true_labels, pre_labels = label_dic["真实标签"], label_dic["预测标签"]
        pre_positive_num += len(pre_labels)
        GT_positive_num += len(true_labels)
        for threshold in thresholds:
            # Track matched predictions by position rather than value, so two
            # predictions with the same timestamp can both be matched.
            used = set()
            for true_label in true_labels:
                for j, pre_label in enumerate(pre_labels):
                    # Predictions are sorted: nothing further can fall in the window.
                    if pre_label > true_label + threshold:
                        break
                    if j in used:
                        continue
                    # Upper bound is already guaranteed by the break above.
                    if true_label - threshold <= pre_label:
                        correct_num_dict[threshold] += 1
                        used.add(j)
                        break
    return {
        threshold: {
            "precision": correct / (pre_positive_num + 1e-8),
            "recall": correct / (GT_positive_num + 1e-8),
        }
        for threshold, correct in correct_num_dict.items()
    }
class MInterface(pl.LightningModule):
    """PyTorch Lightning wrapper that trains and evaluates a TransNetV2 model.

    ``args`` must provide ``batch_size``, ``lr``, ``raw_transnet_weights``,
    ``optim`` and ``lr_scheduler``.
    """

    # Cross-entropy class weights for [non-transition, transition].
    _CLASS_WEIGHTS = (0.15, 0.85)
    # Frames dropped from each end of a clip before computing losses.
    _EDGE = 15
    # Classifier-head entries removed from the original TransNetV2 checkpoint
    # before loading it with strict=False.
    _HEAD_KEYS = (
        "cls_layer1.weight",
        "cls_layer1.bias",
        "cls_layer2.weight",
        "cls_layer2.bias",
    )
    # Matching tolerance (seconds) -> label used in the logged metric names.
    # NOTE(review): the 0.05 s tolerance is logged as "0.01s"; the name is kept so
    # existing dashboards / checkpoint monitors keep working — confirm and rename.
    _TOLERANCE_LABELS = {0.05: "0.01s", 0.1: "0.1s", 0.2: "0.2s", 0.3: "0.3s"}

    def __init__(self, args):
        super().__init__()
        logger.info("TransNetV2 模型初始化开始...")
        self.args = args
        self.batch_size = self.args.batch_size
        self.learning_rate = self.args.lr
        self.model = TransNetV2()
        ## Parameter initialisation: Xavier-uniform for every Conv2d / Linear weight.
        for m in self.model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
        ## Optionally warm-start from the original TransNetV2 weights (without heads).
        if self.args.raw_transnet_weights is not None:
            # load_state_dict copies values into the model's own parameters, so
            # loading on CPU is safe and avoids failures on machines without the
            # GPU the checkpoint was saved from.
            checkpoint = torch.load(
                self.args.raw_transnet_weights, map_location="cpu"
            )
            for key in self._HEAD_KEYS:
                del checkpoint[key]
            self.model.load_state_dict(checkpoint, strict=False)
            print("载入原始模型权重")
        logger.info("TransNetV2 模型初始化结束")

    def _weighted_ce(self, pred, target):
        """Class-weighted cross-entropy over the central frames of each clip.

        ``pred`` holds 2-class logits ``[clips, frames, 2]`` and ``target`` the class
        indices ``[clips, frames]``; the first/last ``_EDGE`` frames are ignored.
        """
        edge = self._EDGE
        weight = torch.tensor(self._CLASS_WEIGHTS, device=pred.device).type_as(pred)
        return F.cross_entropy(
            pred[:, edge:-edge, :].reshape(-1, 2),
            target[:, edge:-edge].reshape(-1),
            weight=weight,
        )

    def training_step(self, batch, batch_idx):
        """Run the model on a training batch; the loss is computed in training_step_end."""
        frames, one_hot_gt, many_hot_gt = (
            batch["frames"],
            batch["one_hot"],
            batch["many_hot"],
        )
        single_frame_pred, all_frame_pred = self.model(frames)
        return single_frame_pred, all_frame_pred, one_hot_gt, many_hot_gt

    def training_step_end(self, output):
        """Mix the single-frame (0.9) and all-frame (0.1) losses and log ``train_loss``."""
        single_frame_pred, all_frame_pred, one_hot_gt, many_hot_gt = output
        loss_total = (
            self._weighted_ce(single_frame_pred, one_hot_gt) * 0.9
            + self._weighted_ce(all_frame_pred, many_hot_gt) * 0.1
        )
        self.log(
            "train_loss",
            loss_total,
            on_epoch=True,
            on_step=True,
            prog_bar=True,
            logger=True,
        )
        return loss_total

    def validation_step(self, batch, batch_idx):
        """Run the model on a validation batch and return everything needed later."""
        frames, one_hot_gt, many_hot_gt = (
            batch["frames"],
            batch["one_hot"],
            batch["many_hot"],
        )
        single_frame_pred, all_frame_pred = self.model(frames)
        return (
            single_frame_pred,
            all_frame_pred,
            one_hot_gt,
            many_hot_gt,
            batch["mp4_ids"],
            batch["batch_mp4_scenes_index"],
            batch["fps_batch"],
        )

    def validation_step_end(self, output):
        """Log ``val_loss`` and pass the step output on to ``validation_epoch_end``."""
        single_frame_pred, all_frame_pred, one_hot_gt, many_hot_gt, _, _, _ = output
        # NOTE(review): the 0.8 / 0.2 mix differs from the 0.9 / 0.1 used for
        # train_loss, so the two curves are not directly comparable — confirm.
        loss_total = (
            self._weighted_ce(single_frame_pred, one_hot_gt) * 0.8
            + self._weighted_ce(all_frame_pred, many_hot_gt) * 0.2
        )
        self.log(
            "val_loss",
            loss_total,
            on_epoch=True,
            on_step=True,
            prog_bar=True,
            logger=True,
        )
        # Return the step output explicitly: validation_epoch_end unpacks these
        # tuples and should not depend on Lightning substituting the
        # validation_step output when this hook returns None.
        return output

    def validation_epoch_end(self, output):
        """Log transition precision/recall for several class thresholds and tolerances.

        For each class threshold, every validation batch is converted into
        predicted transition times with ``complete_results_batch`` and scored
        against the ground truth with ``pr_call``.
        """
        start = time.time()
        # Move each batch to CPU once, instead of once per class threshold.
        batches = [
            (single.cpu().float(), ids.cpu(), n_clips.cpu(), fps.cpu())
            for single, _, _, _, ids, n_clips, fps in output
        ]
        tolerances = list(self._TOLERANCE_LABELS)
        for class_threshold in [0.1, 0.3, 0.5, 0.7]:
            transition_label_list = []
            for single_frame_pred, mp4_ids, batch_mp4_scenes_index, fps_batch in batches:
                transition_label_list.extend(
                    complete_results_batch(
                        mp4_ids,
                        batch_mp4_scenes_index,
                        fps_batch,
                        single_frame_pred,
                        class_threshold,
                    )
                )
            custom_indicator = pr_call(transition_label_list, thresholds=tolerances)
            for tolerance, label in self._TOLERANCE_LABELS.items():
                for suffix, metric in (("P", "precision"), ("R", "recall")):
                    self.log(
                        f"{class_threshold}_{label}_{suffix}",
                        custom_indicator[tolerance][metric],
                        on_epoch=True,
                        on_step=False,
                        prog_bar=False,
                        logger=True,
                    )
        print("推理耗时:{}".format(time.time() - start))

    ## Optimizer configuration
    def configure_optimizers(self):
        """Build the optimizer (SGD w/ momentum 0.9 if ``args.optim == "SGD"``, else Adam)
        and the optional learning-rate scheduler selected by ``args.lr_scheduler``.

        Raises:
            ValueError: if ``args.lr_scheduler`` is not one of "OneCycleLR",
                "CosineAnnealingLR", "None" or None.
        """
        logger.info("configure_optimizers 初始化开始...")
        if self.args.optim == "SGD":
            optimizer = torch.optim.SGD(
                self.parameters(), lr=self.learning_rate, momentum=0.9
            )
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        scheduler_name = self.args.lr_scheduler
        if scheduler_name == "OneCycleLR":
            # NOTE(review): max_lr / epochs / steps_per_epoch are hard-coded, and no
            # "interval" is given, so Lightning presumably steps this once per
            # epoch — confirm this matches the intended schedule.
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=0.0002, verbose=True, epochs=500, steps_per_epoch=7
            )
        elif scheduler_name == "CosineAnnealingLR":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=200, eta_min=5e-7, verbose=True, last_epoch=-1
            )
        elif scheduler_name in ("None", None):
            logger.info("configure_optimizers 初始化结束...")
            return optimizer
        else:
            # Previously this fell through and returned None (no optimizer at all).
            raise ValueError(f"Unsupported lr_scheduler: {scheduler_name!r}")
        logger.info("configure_optimizers 初始化结束...")
        return [optimizer], [scheduler]
|