#!/usr/bin/env python

"""features.py: Feature extraction."""

import logging

import numpy as np
import scipy.ndimage
import tensorflow as tf

from ..utils import TFModel

try:
    # scipy.misc.imresize was removed in SciPy 1.3 (and scipy.misc itself was
    # later deprecated/removed); use it only when the installed SciPy has it.
    from scipy.misc import imresize as _legacy_imresize
except ImportError:
    _legacy_imresize = None


def _resize_image(image, size):
    """Resize an image array so its first two axes become ``size[0]`` x ``size[1]``.

    When the installed SciPy still provides ``scipy.misc.imresize`` it is used
    unchanged, so results match the historical behaviour of this module
    (note: ``imresize`` may byte-scale non-uint8 input to 0-255 and returns
    uint8). Otherwise the image is resampled with first-order spline
    interpolation via ``scipy.ndimage.zoom``, which keeps the input dtype and
    does not rescale pixel values.

    Args:
        image: Array-like image whose first two axes are treated as
            height and width; any further axes (e.g. channels) are kept as-is.
        size: Sequence whose first two entries are the target height and
            width (extra entries, such as a channel count, are ignored by the
            fallback path).

    Returns:
        The resized image as a numpy array.
    """
    if _legacy_imresize is not None:
        return _legacy_imresize(image, size)

    image = np.asarray(image)
    target_height, target_width = int(size[0]), int(size[1])
    # float() avoids integer division under Python 2.
    factors = [
        float(target_height) / image.shape[0],
        float(target_width) / image.shape[1],
    ]
    # Leave any trailing axes (e.g. channels) untouched.
    factors += [1.0] * (image.ndim - 2)
    return scipy.ndimage.zoom(image, factors, order=1)


class Features(TFModel):
    """Extract flattened feature vectors from a TensorFlow graph.

    The tensors named by ``config["input-tensor"]`` and
    ``config["features-tensor"]`` are looked up in ``self.tf_graph``
    (presumably created by ``TFModel.__init__``; not visible here), and the
    feature tensor is flattened so each image yields one feature row.
    """

    def __init__(self, config):
        """Locate the input and feature tensors and build the flatten op.

        Args:
            config: Mapping passed to ``TFModel`` that must also contain the
                keys ``"features-tensor"`` and ``"input-tensor"``, each the
                name of a tensor in the model graph.

        On failure to resolve either tensor, the cause is logged and
        ``self._report_error`` (inherited from ``TFModel``) is called.
        """
        TFModel.__init__(self, config)
        try:
            # Only the lookups live in the try block so unrelated errors are
            # not misreported as "could not access tensors".
            feature_tensor = self.tf_graph.get_tensor_by_name(config["features-tensor"])
            input_tensor = self.tf_graph.get_tensor_by_name(config["input-tensor"])
        except (KeyError, ValueError, TypeError) as err:
            # KeyError: missing config key or tensor not in graph;
            # ValueError/TypeError: malformed tensor name.
            logging.error("Tensor lookup failed: %s", err)
            self._report_error("Could not access tensors by name")
        else:
            # Create the flatten op in the same graph as its input tensor, so
            # the session (presumably bound to self.tf_graph) can fetch it.
            with self.tf_graph.as_default():
                self.tf_feature_out = tf.keras.layers.Flatten()(feature_tensor)

            self.tf_input = input_tensor
            # Drop the leading axis, presumably the batch dimension.
            self.input_shape = input_tensor.shape[1:]
            logging.debug("Input shape: %s", self.input_shape)

    def extract_features(self, image):
        """Run ``image`` through the graph and return its flattened features.

        Args:
            image: Image array; presumably (height, width[, channels]) to
                match ``self.input_shape`` — confirm against callers.

        Returns:
            The evaluated flattened feature tensor for a batch containing only
            ``image`` (its first axis has length 1).
        """
        resized_image = _resize_image(image, self.input_shape)
        # The graph input carries a leading batch axis; feed a batch of one.
        batch = np.expand_dims(resized_image, axis=0)
        return self.tf_session.run(self.tf_feature_out, feed_dict={self.tf_input: batch})