webcamdemo.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. #!/usr/bin/env python
  2. """webcam_demo: A live webcam stream with overlayed information."""
  3. import sys
  4. import cv2
  5. import numpy as np
  6. import tensorflow as tf
  7. import collections
  8. import os
  9. import six.moves.urllib as urllib
  10. import tarfile
  11. import time
  12. from threading import Thread
  13. from queue import Queue
  14. from object_detection.utils import label_map_util
  15. from object_detection.utils import visualization_utils as vis_util
  # Model archive to download from the TF object-detection model zoo.
  # Alternative archives that were tried with this script:
  #   MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
  #   MODEL_NAME = 'faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017'
  MODEL_NAME = 'ssd_inception_v2_coco_11_06_2017'
  MODEL_FILE = MODEL_NAME + '.tar.gz'
  # HTTPS: the downloaded graph is parsed and run, so it must not be
  # replaceable in transit.
  DOWNLOAD_BASE = 'https://download.tensorflow.org/models/object_detection/'
  # Path to the frozen detection graph inside the extracted archive; this is
  # the model actually used for object detection.
  PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
  # Label map used to attach a human-readable label to each detected box.
  PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
  # Upper bound on class ids read from the label map.
  NUM_CLASSES = 90
  class FileVideoStream:
      """Read frames from a cv2.VideoCapture on a background daemon thread.

      The producer thread keeps at most ``queue_size`` frames buffered. When the
      buffer is full the oldest frame is dropped, so ``read()`` returns recent
      frames rather than an ever-growing backlog.
      """

      def __init__(self, device=0, queue_size=2):
          """Open ``device`` and probe it with a single read.

          Args:
              device: source passed straight to cv2.VideoCapture.
              queue_size: maximum number of buffered frames.
          """
          self.stream = cv2.VideoCapture(device)
          # Probe once so callers can detect an unusable source via ``stopped``
          # before starting the thread. The probed frame is discarded.
          ret, _ = self.stream.read()
          if not ret:
              print('Error: read() returned false. IsOpened: %r' % (self.stream.isOpened()))
          self.stopped = not ret
          self.Q = Queue(maxsize=queue_size)
          # Set by start(); initialised here so close() works without start().
          self.running_thread = None

      def start(self):
          """Start the producer thread and return ``self`` for chaining."""
          t = Thread(target=self.update, args=())
          t.daemon = True
          t.start()
          self.running_thread = t
          return self

      def update(self):
          """Producer loop: read frames until stopped or a read fails."""
          from queue import Empty, Full

          while not self.stopped:
              ret, frame = self.stream.read()
              if not ret:
                  self.stop()
                  return
              try:
                  self.Q.put_nowait(frame)
              except Full:
                  # Drop the oldest frame. The consumer may have drained the
                  # queue in the meantime, so Empty is expected and harmless;
                  # this thread is the only producer, so there is room after.
                  try:
                      self.Q.get_nowait()
                  except Empty:
                      pass
                  self.Q.put_nowait(frame)
              time.sleep(0)  # give other threads a chance to run

      def read(self):
          """Return the next buffered frame, or None once the stream is exhausted.

          Blocks while the producer is running and nothing is buffered. After the
          stream has stopped, any frames still buffered are returned first; then
          ``None`` is returned instead of blocking forever.
          """
          from queue import Empty

          while True:
              if self.stopped:
                  try:
                      return self.Q.get_nowait()
                  except Empty:
                      return None
              try:
                  # Short timeout so a producer that stops while we wait is noticed.
                  return self.Q.get(timeout=0.1)
              except Empty:
                  continue

      def more(self):
          """Return True if at least one frame is currently buffered."""
          return self.Q.qsize() > 0

      def stop(self):
          """Ask the producer thread to exit after its current read."""
          self.stopped = True

      def close(self):
          """Stop the producer, wait for it (if started) and release the capture."""
          self.stop()
          if self.running_thread is not None:
              self.running_thread.join()
              self.running_thread = None
          self.stream.release()
  def _ensure_frozen_graph():
      """Make sure PATH_TO_CKPT exists, downloading/unpacking MODEL_FILE if needed.

      Both the archive and the extracted graph are first written to a ``.part``
      file and renamed when complete, so an interrupted run is retried next time
      instead of leaving a truncated file behind. Only the frozen graph is
      extracted, and it is written to PATH_TO_CKPT rather than to the member's
      own (archive-controlled) path.

      Raises:
          RuntimeError: if the archive contains no frozen_inference_graph.pb.
      """
      import shutil

      if os.path.exists(PATH_TO_CKPT):
          return
      if not os.path.exists(MODEL_FILE):
          print('Downloading model...')
          partial_archive = MODEL_FILE + '.part'
          urllib.request.urlretrieve(DOWNLOAD_BASE + MODEL_FILE, partial_archive)
          os.replace(partial_archive, MODEL_FILE)
      with tarfile.open(MODEL_FILE) as tar_file:
          for member in tar_file.getmembers():
              if member.isfile() and os.path.basename(member.name) == 'frozen_inference_graph.pb':
                  os.makedirs(os.path.dirname(PATH_TO_CKPT) or '.', exist_ok=True)
                  partial_graph = PATH_TO_CKPT + '.part'
                  with tar_file.extractfile(member) as src, open(partial_graph, 'wb') as dst:
                      shutil.copyfileobj(src, dst)
                  os.replace(partial_graph, PATH_TO_CKPT)
                  return
      raise RuntimeError('%s does not contain frozen_inference_graph.pb' % MODEL_FILE)


  def _load_detection_graph():
      """Load the frozen graph at PATH_TO_CKPT into a new tf.Graph and return it."""
      detection_graph = tf.Graph()
      with detection_graph.as_default():
          od_graph_def = tf.GraphDef()
          with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
              od_graph_def.ParseFromString(fid.read())
          tf.import_graph_def(od_graph_def, name='')
      return detection_graph


  def _load_category_index():
      """Build the class-id -> category index from the label map at PATH_TO_LABELS."""
      label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
      categories = label_map_util.convert_label_map_to_categories(
          label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
      return label_map_util.create_category_index(categories)


  def main(cam_id):
      """Run object detection on a live capture stream and show it full-screen.

      Press 'q' in the window to stop; average throughput is then printed.
      Exits with status -1 if the stream cannot be opened.

      Args:
          cam_id: capture source passed to FileVideoStream (cv2.VideoCapture).
      """
      window_name = 'Webcam Demo'
      frames = 0

      _ensure_frozen_graph()
      detection_graph = _load_detection_graph()
      category_index = _load_category_index()

      with detection_graph.as_default(), tf.Session(graph=detection_graph) as sess:
          cap = FileVideoStream(cam_id).start()
          try:
              if cap.stopped:
                  print('Error: Stream is not available')
                  sys.exit(-1)

              # Input and output tensors of the detection graph.
              image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
              # Each box is a region of the image where an object was detected.
              detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
              # Confidence for each box; drawn together with the class label.
              detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
              detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')

              first = time.time()
              last = first
              cv2.namedWindow(window_name, cv2.WND_PROP_FULLSCREEN)
              cv2.setWindowProperty(window_name, cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
              print("Running...")
              while True:
                  now = time.time()
                  diff = now - last
                  last = now
                  # time.time() may return the same value twice on coarse clocks.
                  fps_string = "FPS: %02.1f" % (1.0 / diff if diff > 0 else 0.0)

                  frame = cap.read()
                  if frame is None:
                      print('Error: Stream ended')
                      break

                  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                  # The model expects a batch of shape [1, None, None, 3].
                  boxes, scores, classes = sess.run(
                      [detection_boxes, detection_scores, detection_classes],
                      feed_dict={image_tensor: np.expand_dims(rgb, axis=0)})
                  vis_util.visualize_boxes_and_labels_on_image_array(
                      frame,
                      np.squeeze(boxes),
                      np.squeeze(classes).astype(np.int32),
                      np.squeeze(scores),
                      category_index,
                      use_normalized_coordinates=True,
                      line_thickness=8)

                  cv2.putText(frame, fps_string, (0, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 255)
                  cv2.putText(frame, "Elapsed time: %02.1fs" % (now - first), (0, 60),
                              cv2.FONT_HERSHEY_SIMPLEX, 0.5, 255)
                  cv2.imshow(window_name, frame)
                  frames += 1

                  if (cv2.waitKey(1) & 0xFF) == ord('q'):
                      elapsed = now - first
                      if elapsed > 0:
                          print("Benchmark done. FPS avg: %02.1f" % (float(frames) / elapsed))
                      print("Time per frame: %.1f ms" % (1000.0 * (float(elapsed) / float(frames))))
                      print("Elapsed time: %02.1fs" % elapsed)
                      break
          finally:
              # Always release the camera and windows, even on errors or sys.exit.
              cap.close()
              cv2.destroyAllWindows()
  165. if __name__ == "__main__":
  166. cam_id = int(sys.argv[1]) if len(sys.argv) > 1 else 0
  167. main(cam_id)