detector.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. #!/usr/bin/env python
  2. """Detector: Face detection implementation."""
  3. import logging
  4. import numpy as np
  5. from ..utils import TFModel
  6. # import tensorflow.contrib.slim as slim
  7. class Detector(TFModel):
  8. def __init__(self, config):
  9. TFModel.__init__(self, config)
  10. try:
  11. # (1) Find feature tensor
  12. self.tf_image_tensor = self.tf_graph.get_tensor_by_name("import/image_tensor:0")
  13. self.tf_detection_boxes = self.tf_graph.get_tensor_by_name('import/detection_boxes:0')
  14. self.tf_detection_scores = self.tf_graph.get_tensor_by_name('import/detection_scores:0')
  15. self.tf_detection_classes = self.tf_graph.get_tensor_by_name('import/detection_classes:0')
  16. self.tf_num_detections = self.tf_graph.get_tensor_by_name('import/num_detections:0')
  17. self.input_shape = self.tf_image_tensor.shape[1:].as_list()
  18. if "downscale-to" in config.keys():
  19. for i in range(len(self.input_shape)):
  20. if self.input_shape[i] is None:
  21. self.input_shape[i] = config["downscale-to"]
  22. logging.debug("Input shape: %s" % self.input_shape)
  23. except:
  24. self._report_error("Could not access tensors by name")
  25. def detect_faces(self, image):
  26. if None not in self.input_shape:
  27. resized_image = image.resize(size=self.input_shape[0:2])
  28. else:
  29. resized_image = image
  30. (boxes, scores, classes, num) = self.tf_session.run(
  31. [self.tf_detection_boxes, self.tf_detection_scores, self.tf_detection_classes, self.tf_num_detections],
  32. feed_dict={self.tf_image_tensor: np.expand_dims(resized_image, axis=0)})
  33. sample_num = int(num[0])
  34. sample_scores = scores[0][0:sample_num]
  35. sample_boxes = boxes[0][0:sample_num]
  36. filtered_boxes = sample_boxes[sample_scores > 0.5]
  37. filtered_scores = sample_scores[sample_scores > 0.5]
  38. ret_boxes = []
  39. for index, box in enumerate(filtered_boxes):
  40. score = sample_scores[index]
  41. ymin, xmin, ymax, xmax = box
  42. ret_box = {'x': float(xmin), 'y': float(ymin),
  43. 'w': float(xmax - xmin), 'h': float(ymax - ymin),
  44. 'score': float(score)}
  45. ret_boxes += [ret_box]
  46. return ret_boxes