Browse Source

YOLO for both img & vdo

master
sipp11 5 years ago
parent
commit
cb027a092b
  1. 10
      examples/yolo_img_obj_detector.py
  2. 162
      examples/yolo_vdo_obj_detector.py

10
examples/yolo_obj_detector.py → examples/yolo_img_obj_detector.py

@ -1,9 +1,9 @@
"""USAGE
python examples/yolo_obj_detector.py \
-c ~/dev/obj-tracking/yolov3.cfg \
-w ~/dev/obj-tracking/yolov3.weights \
-cl ~/dev/obj-tracking/yolo/darknet/data/coco.names \
-i ~/dev/obj-tracking/person.jpg
python examples/yolo_img_obj_detector.py \
-c ~/syncthing/dropbox/handai/obj_tracking/mytrain.cfg \
-w ~/syncthing/dropbox/handai/obj_tracking/mytrain_final.weights \
-cl ~/syncthing/dropbox/handai/obj_tracking/mytrain.names \
-i ~/syncthing/dropbox/handai/obj_tracking/person.jpg
"""
import cv2
import argparse

162
examples/yolo_vdo_obj_detector.py

@ -0,0 +1,162 @@
"""USAGE
python yolo_vdo_obj_detector.py \
-c ~/syncthing/dropbox/handai/obj_tracking/mytrain.cfg \
-w ~/syncthing/dropbox/handai/obj_tracking/mytrain_final.weights \
-cl ~/syncthing/dropbox/handai/obj_tracking/mytrain.names \
-i ~/syncthing/dropbox/handai/data/5min.mp4
"""
from imutils.video import FPS
import cv2
import csv
import argparse
import numpy as np
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--input", required=True, help="path to input vdo")
ap.add_argument("-c", "--config", required=True, help="path to yolo config file")
ap.add_argument(
"-w", "--weights", required=True, help="path to yolo pre-trained weights"
)
ap.add_argument(
"-cl", "--classes", required=True, help="path to text file containing class names"
)
args = ap.parse_args()
def get_output_layers(net):
layer_names = net.getLayerNames()
output_layers = [layer_names[i[0] - 1] for i in net.getUnconnectedOutLayers()]
return output_layers
def draw_prediction(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
label = str(classes[class_id])
color = COLORS[class_id]
cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
classes = None
with open(args.classes, "r") as f:
classes = [line.strip() for line in f.readlines()]
COLORS = np.random.uniform(0, 255, size=(len(classes), 3))
net = cv2.dnn.readNet(args.weights, args.config)
vs = cv2.VideoCapture(args.input)
_fps = vs.get(cv2.CAP_PROP_FPS)
Width = vs.get(cv2.CAP_PROP_FRAME_WIDTH)
Height = vs.get(cv2.CAP_PROP_FRAME_HEIGHT)
# Width = Height = None
scale = 0.00392
print(f'{_fps} fps {Width}x{Height} px')
writer = None
# initialize the list of object trackers and corresponding class
# labels
# trackers = []
labels = []
# start the frames per second throughput estimator
fps = FPS().start()
frame_count = 0
f = open('yolo_output_txt.csv', 'wt')
fieldnames = ['frame', 'what', 'x', 'y', 'w', 'h']
cf = csv.DictWriter(f, fieldnames=fieldnames)
cf.writeheader()
# loop over frames from the video file stream
while True:
# grab the next frame from the video file
(grabbed, frame) = vs.read()
frame_count += 1
_duration = frame_count / _fps
# check to see if we have reached the end of the video file
if frame is None:
break
blob = cv2.dnn.blobFromImage(frame, scale, (416, 416), (0, 0, 0), True, crop=False)
net.setInput(blob)
outs = net.forward(get_output_layers(net))
class_ids = []
confidences = []
boxes = []
conf_threshold = 0.5
nms_threshold = 0.4
for out in outs:
for detection in out:
scores = detection[5:]
class_id = np.argmax(scores)
confidence = scores[class_id]
if confidence > 0.5:
center_x = int(detection[0] * Width)
center_y = int(detection[1] * Height)
w = int(detection[2] * Width)
h = int(detection[3] * Height)
x = center_x - w / 2
y = center_y - h / 2
class_ids.append(class_id)
confidences.append(float(confidence))
boxes.append([x, y, w, h])
indices = cv2.dnn.NMSBoxes(boxes, confidences, conf_threshold, nms_threshold)
for i in indices:
i = i[0]
box = boxes[i]
x = round(box[0])
y = round(box[1])
w = round(box[2])
h = round(box[3])
_cls_id = class_ids[i]
_row = {
'frame': frame_count,
'what': str(classes[_cls_id]),
'x': x,
'y': y,
'w': w,
'h': h,
}
cf.writerow(_row)
draw_prediction(
frame,
_cls_id,
confidences[i],
round(x),
round(y),
round(x + w),
round(y + h),
)
# show the output frame
cv2.imshow("Frame", frame)
key = cv2.waitKey(1) & 0xFF
# if the `q` key was pressed, break from the loop
if key == ord("q"):
break
# update the FPS counter
# fps.update()
# stop the timer and display FPS information
fps.stop()
print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))
# do a bit of cleanup
cv2.destroyAllWindows()
vs.release()
f.close()
Loading…
Cancel
Save