Skip to content
This repository was archived by the owner on Nov 2, 2024. It is now read-only.

Commit bf70b69

Browse files
committed
perf(detector): 添加函数:绘制边界框及其分类概率
1 parent 0776efa commit bf70b69

File tree

1 file changed

+32
-4
lines changed

1 file changed

+32
-4
lines changed

py/car_detector.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def get_model(device=None):
3939
num_classes = 2
4040
num_features = model.classifier[6].in_features
4141
model.classifier[6] = nn.Linear(num_features, num_classes)
42-
model.load_state_dict(torch.load('./models/linear_svm_alexnet_car_4.pth'))
42+
# model.load_state_dict(torch.load('./models/linear_svm_alexnet_car_4.pth'))
43+
model.load_state_dict(torch.load('./models/best_linear_svm_alexnet_car.pth'))
4344
model.eval()
4445

4546
# 取消梯度追踪
@@ -51,6 +52,22 @@ def get_model(device=None):
5152
return model
5253

5354

55+
def draw_box_with_text(img, rect_list, score_list):
56+
"""
57+
绘制边框及其分类概率
58+
:param img:
59+
:param rect_list:
60+
:param score_list:
61+
:return:
62+
"""
63+
for i in range(len(rect_list)):
64+
xmin, ymin, xmax, ymax = rect_list[i]
65+
score = score_list[i]
66+
67+
cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=1)
68+
cv2.putText(img, "{:.3f}".format(score), (xmin, ymin), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
69+
70+
5471
if __name__ == '__main__':
5572
device = get_device()
5673
transform = get_transform()
@@ -68,13 +85,18 @@ def get_model(device=None):
6885
bndboxs = parse_xml(test_xml_path)
6986
for bndbox in bndboxs:
7087
xmin, ymin, xmax, ymax = bndbox
71-
cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 255, 0), thickness=2)
88+
cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 255, 0), thickness=1)
7289

7390
# 候选区域建议
7491
selectivesearch.config(gs, img, strategy='f')
7592
rects = selectivesearch.get_rects(gs)
7693
print('候选区域建议数目: %d' % len(rects))
7794

95+
# softmax = torch.softmax()
96+
97+
# 保存正样本边界框以及
98+
score_list = list()
99+
positive_list = list()
78100
for rect in rects:
79101
xmin, ymin, xmax, ymax = rect
80102
rect_img = img[ymin:ymax, xmin:xmax]
@@ -86,8 +108,14 @@ def get_model(device=None):
86108
"""
87109
预测为汽车
88110
"""
89-
cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=2)
90-
print(rect, output)
111+
probs = torch.softmax(output, dim=0).cpu().numpy()
112+
113+
score_list.append(probs[1])
114+
positive_list.append(rect)
115+
# cv2.rectangle(dst, (xmin, ymin), (xmax, ymax), color=(0, 0, 255), thickness=2)
116+
print(rect, output, probs)
117+
118+
draw_box_with_text(dst, positive_list, score_list)
91119

92120
cv2.imshow('img', dst)
93121
cv2.waitKey(0)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy