features.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. #!/usr/bin/env python
  2. """features.py: Feature extraction."""
  3. # import tensorflow.contrib.slim as slim
  4. import logging
  5. import numpy as np
  6. import scipy.misc
  7. import tensorflow as tf
  8. from ..utils import TFModel
  class Features(TFModel):
      """Extract flattened feature vectors from a TensorFlow graph.

      The graph and session (``self.tf_graph``, ``self.tf_session``) and the
      error hook ``self._report_error`` come from the ``TFModel`` base class.
      ``config`` names two tensors in that graph:

      * ``"features-tensor"`` -- the activations returned by
        ``extract_features`` (flattened to one row per image);
      * ``"input-tensor"`` -- the tensor that images are fed into.
      """

      def __init__(self, config):
          """Look up the feature and input tensors named in ``config``.

          Args:
              config: configuration mapping, forwarded to ``TFModel``. Its
                  ``"features-tensor"`` and ``"input-tensor"`` entries must be
                  tensor names present in ``self.tf_graph`` (presumably of the
                  ``"scope/op:0"`` form -- TODO confirm).

          Lookup failures (missing config key, unknown tensor name, name of
          an operation rather than a tensor, tensor rejected by Flatten) are
          passed to ``self._report_error`` together with the underlying
          error. Any other exception propagates instead of being silently
          swallowed.
          """
          TFModel.__init__(self, config)
          try:
              # (1) Find the feature tensor and flatten it to (batch, -1).
              feature_tensor = self.tf_graph.get_tensor_by_name(
                  config["features-tensor"])
              # Create the flatten op inside the model's graph, so it lives
              # next to feature_tensor and can be run by self.tf_session even
              # when self.tf_graph is not the process-wide default graph.
              with self.tf_graph.as_default():
                  self.tf_feature_out = tf.keras.layers.Flatten()(feature_tensor)
              # (2) Find the input tensor. Its shape minus the leading batch
              # dimension is the size that images are resized to.
              self.tf_input = self.tf_graph.get_tensor_by_name(
                  config["input-tensor"])
              self.input_shape = self.tf_input.shape[1:]
              logging.debug("Input shape: %s", self.input_shape)
          except (KeyError, ValueError, TypeError) as err:
              # KeyError: config key missing or tensor not in the graph.
              # ValueError: the name refers to an operation rather than a
              # tensor, or Flatten rejects the tensor.
              # TypeError: the configured name is not a string.
              self._report_error("Could not access tensors by name: %s" % err)

      def extract_features(self, image):
          """Return the flattened feature activations for a single image.

          Args:
              image: array-like image, presumably (height, width[, channels])
                  -- TODO confirm. It is resized to the input tensor's
                  spatial size and fed as a batch of one.

          Returns:
              The value ``self.tf_session.run`` yields for the flattened
              feature tensor: one row (leading batch dimension of 1).
          """
          resized_image = self._resize_to_input(image)
          batch = np.expand_dims(resized_image, axis=0)
          features = self.tf_session.run(self.tf_feature_out,
                                         feed_dict={self.tf_input: batch})
          return features

      def _resize_to_input(self, image):
          """Resize ``image`` to the (height, width) of ``self.input_shape``.

          Uses ``scipy.misc.imresize`` (bilinear, byte-scales to uint8) when
          the installed SciPy still provides it, so results are unchanged
          there. It was removed in SciPy 1.3. On newer SciPy this falls back
          to ``scipy.ndimage.zoom`` with ``order=1`` (bilinear), which keeps
          the input dtype and does not antialias when shrinking.

          If the input tensor's height or width is not statically known, the
          image is returned unchanged and fed as-is.
          """
          # TF1 TensorShape yields Dimension objects (int in ``.value``);
          # TF2 yields plain ints / None.
          dims = [getattr(dim, "value", dim) for dim in self.input_shape]
          height, width = dims[0], dims[1]
          if height is None or width is None:
              return image
          imresize = getattr(scipy.misc, "imresize", None)
          if imresize is not None:
              # Same call the original code made; imresize reads only the
              # first two entries of the size tuple.
              return imresize(image, (int(height), int(width)))
          from scipy import ndimage
          image = np.asarray(image)
          # float() keeps the division exact under Python 2 as well. Trailing
          # (channel) axes are left unscaled.
          factors = ([float(height) / image.shape[0],
                      float(width) / image.shape[1]]
                     + [1.0] * (image.ndim - 2))
          return ndimage.zoom(image, factors, order=1)