@@ -39,7 +39,8 @@ def get_model(device=None):
39
39
num_classes = 2
40
40
num_features = model .classifier [6 ].in_features
41
41
model .classifier [6 ] = nn .Linear (num_features , num_classes )
42
- model .load_state_dict (torch .load ('./models/linear_svm_alexnet_car_4.pth' ))
42
+ # model.load_state_dict(torch.load('./models/linear_svm_alexnet_car_4.pth'))
43
+ model .load_state_dict (torch .load ('./models/best_linear_svm_alexnet_car.pth' ))
43
44
model .eval ()
44
45
45
46
# 取消梯度追踪
@@ -51,6 +52,22 @@ def get_model(device=None):
51
52
return model
52
53
53
54
55
+ def draw_box_with_text (img , rect_list , score_list ):
56
+ """
57
+ 绘制边框及其分类概率
58
+ :param img:
59
+ :param rect_list:
60
+ :param score_list:
61
+ :return:
62
+ """
63
+ for i in range (len (rect_list )):
64
+ xmin , ymin , xmax , ymax = rect_list [i ]
65
+ score = score_list [i ]
66
+
67
+ cv2 .rectangle (img , (xmin , ymin ), (xmax , ymax ), color = (0 , 0 , 255 ), thickness = 1 )
68
+ cv2 .putText (img , "{:.3f}" .format (score ), (xmin , ymin ), cv2 .FONT_HERSHEY_SIMPLEX , 0.5 , (255 , 255 , 255 ), 1 )
69
+
70
+
54
71
if __name__ == '__main__' :
55
72
device = get_device ()
56
73
transform = get_transform ()
@@ -68,13 +85,18 @@ def get_model(device=None):
68
85
bndboxs = parse_xml (test_xml_path )
69
86
for bndbox in bndboxs :
70
87
xmin , ymin , xmax , ymax = bndbox
71
- cv2 .rectangle (dst , (xmin , ymin ), (xmax , ymax ), color = (0 , 255 , 0 ), thickness = 2 )
88
+ cv2 .rectangle (dst , (xmin , ymin ), (xmax , ymax ), color = (0 , 255 , 0 ), thickness = 1 )
72
89
73
90
# 候选区域建议
74
91
selectivesearch .config (gs , img , strategy = 'f' )
75
92
rects = selectivesearch .get_rects (gs )
76
93
print ('候选区域建议数目: %d' % len (rects ))
77
94
95
+ # softmax = torch.softmax()
96
+
97
+ # 保存正样本边界框以及
98
+ score_list = list ()
99
+ positive_list = list ()
78
100
for rect in rects :
79
101
xmin , ymin , xmax , ymax = rect
80
102
rect_img = img [ymin :ymax , xmin :xmax ]
@@ -86,8 +108,14 @@ def get_model(device=None):
86
108
"""
87
109
预测为汽车
88
110
"""
89
- cv2 .rectangle (dst , (xmin , ymin ), (xmax , ymax ), color = (0 , 0 , 255 ), thickness = 2 )
90
- print (rect , output )
111
+ probs = torch .softmax (output , dim = 0 ).cpu ().numpy ()
112
+
113
+ score_list .append (probs [1 ])
114
+ positive_list .append (rect )
115
+ # cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=2)
116
+ print (rect , output , probs )
117
+
118
+ draw_box_with_text (dst , positive_list , score_list )
91
119
92
120
cv2 .imshow ('img' , dst )
93
121
cv2 .waitKey (0 )
0 commit comments