#!/usr/bin/env python
"""utils: Various helper functions."""

import logging
import os.path
import sys

import google.protobuf.message
import tensorflow as tf
from google.protobuf import text_format


def create_session_from_config(config):
    """Build a TF1-style session for the model described by ``config``.

    File names in ``config`` are resolved relative to
    ``config["distribution-root"]``. The first matching strategy wins:

    1. ``"model-frozengraph"``: a binary, frozen ``GraphDef``.
    2. ``"model-graphdef"`` + ``"model-checkpoint"``: a ``GraphDef`` (binary
       when the name ends in ``pb``, text when it ends in ``graph``/``pbtxt``)
       whose variables are then restored from the checkpoint.
    3. ``"model-metagraph"`` + ``"model-checkpoint"``: a ``MetaGraphDef``
       imported under the ``"import"`` scope, then restored from the
       checkpoint.

    A ``"saved_model"`` key is recognised but SavedModel loading is not
    implemented; it is only logged.

    Args:
        config: Mapping with the keys described above.

    Returns:
        A ``tf.compat.v1.Session`` bound to the loaded graph, or ``None`` when
        no supported model key is present or loading failed. Failures are
        logged, never raised.
    """
    try:
        if "saved_model" in config:
            # TODO: implement SavedModel loading; for now only log the path and
            # fall through to the other strategies.
            logging.debug('Loading saved model %s',
                          _model_path(config, "saved_model"))

        if "model-frozengraph" in config:
            return _load_frozen_graph(config)
        if "model-graphdef" in config:
            return _load_graphdef_with_checkpoint(config)
        if "model-metagraph" in config:
            return _load_metagraph_with_checkpoint(config)

        if "saved_model" in config:
            logging.warning("SavedModel loading is not implemented; config: %s",
                            config)
        else:
            logging.error("No supported model key found in config: %s", config)
    except OSError as os_error:
        logging.error("Error while attempting to load: %s", os_error)
        logging.error("Config: %s", config)
    except google.protobuf.message.DecodeError as decode_error:
        logging.error("Error while attempting to load: %s", decode_error)
        logging.error("Config: %s", config)
    except Exception:
        # Best-effort loader: report anything else (graph import, checkpoint
        # restore, unknown file type, missing keys) with a traceback and fall
        # through to ``None``. KeyboardInterrupt/SystemExit still propagate.
        exc_type, exc_value = sys.exc_info()[:2]
        logging.exception("Error while attempting to load: %s, %s",
                          exc_type, exc_value)
        logging.error("Config: %s", config)
    return None


def _model_path(config, key):
    """Return ``config[key]`` joined onto ``config["distribution-root"]``."""
    return os.path.join(config["distribution-root"], config[key])


def _parse_binary_graph_def(path):
    """Read a binary-serialised ``GraphDef`` from ``path``."""
    graph_definition = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(path, 'rb') as graph_file:
        logging.debug("Opened file, deserializing...")
        graph_definition.ParseFromString(graph_file.read())
    return graph_definition


def _parse_text_graph_def(path):
    """Read a text-format ``GraphDef`` from ``path``."""
    graph_definition = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(path, 'r') as graph_file:
        logging.debug("Opened file, deserializing...")
        text_format.Merge(graph_file.read(), graph_definition)
    return graph_definition


def _restore_or_close(session, saver, checkpoint_path):
    """Restore ``checkpoint_path`` into ``session``; close it if that fails."""
    logging.debug('Restoring checkpoint %s', checkpoint_path)
    try:
        saver.restore(sess=session, save_path=checkpoint_path)
    except BaseException:
        # Don't leak the session (and its device memory) on a failed restore.
        session.close()
        raise
    return session


def _load_frozen_graph(config):
    """Strategy (1): import a frozen binary GraphDef into a fresh graph."""
    full_path = _model_path(config, "model-frozengraph")
    logging.debug("Loading frozen graph: %s", full_path)
    graph_definition = _parse_binary_graph_def(full_path)
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_definition)
    logging.debug("Done, creating session.")
    return _create_tf_session(graph=graph)


def _load_graphdef_with_checkpoint(config):
    """Strategy (2): import a GraphDef, then restore variables from a checkpoint."""
    full_path_def = _model_path(config, "model-graphdef")
    full_path_ckpt = _model_path(config, "model-checkpoint")
    logging.debug("Loading graphdef: %s", full_path_def)

    if full_path_def.endswith('pb'):
        graph_definition = _parse_binary_graph_def(full_path_def)
    elif full_path_def.endswith(('graph', 'pbtxt')):
        graph_definition = _parse_text_graph_def(full_path_def)
    else:
        raise ValueError("Unknown file type: %s" % full_path_def)

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_definition)
        # The Saver must be built inside ``graph``: outside it, it would look
        # for variables in the unrelated default graph (or fail under eager).
        # NOTE(review): import_graph_def does not populate the variable
        # collections and prefixes names with "import/", so Saver() may still
        # find no variables here -- confirm against a real model.
        saver = tf.compat.v1.train.Saver()
    logging.debug("Done, creating session.")
    session = _create_tf_session(graph=graph)
    return _restore_or_close(session, saver, full_path_ckpt)


def _load_metagraph_with_checkpoint(config):
    """Strategy (3): import a MetaGraphDef into a fresh graph and restore it."""
    full_path_meta = _model_path(config, "model-metagraph")
    full_path_ckpt = _model_path(config, "model-checkpoint")
    graph = tf.Graph()
    logging.debug('Loading %s', full_path_meta)
    # Import into a dedicated graph (like the other strategies) rather than the
    # process-wide default graph; import_meta_graph is graph-mode only.
    with graph.as_default():
        saver = tf.compat.v1.train.import_meta_graph(
            full_path_meta, clear_devices=True, import_scope="import")
    logging.debug('Importing metagraph, creating session...')
    session = _create_tf_session(graph=graph)
    return _restore_or_close(session, saver, full_path_ckpt)


def _create_tf_session(graph=None):
    """Create a ``tf.compat.v1.Session``, enabling GPU memory growth first.

    Args:
        graph: Graph to bind the session to; ``None`` uses the default graph.

    Returns:
        The new session.
    """
    for device in tf.config.experimental.list_physical_devices('GPU'):
        try:
            tf.config.experimental.set_memory_growth(device, True)
        except RuntimeError as growth_error:
            # TensorFlow refuses to change this once the runtime has been
            # initialised (e.g. by an earlier load); not fatal for us.
            logging.warning("Could not enable memory growth on %s: %s",
                            device, growth_error)
    if graph is None:
        return tf.compat.v1.Session()
    return tf.compat.v1.Session(graph=graph)