123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- #!/usr/bin/env python
- """utils: Various helper functions."""
- import logging
- import os.path
- import sys
- import google.protobuf.message
- import tensorflow as tf
- from google.protobuf import text_format
def create_session_from_config(config):
    """Create a TensorFlow session for the model described by ``config``.

    ``config`` is a mapping; every model path in it is joined onto
    ``config["distribution-root"]``. The first of these keys that is present
    selects the loader:

    1. ``"model-frozengraph"`` -- binary frozen ``GraphDef``.
    2. ``"model-graphdef"`` -- ``GraphDef`` (binary when the path ends in
       ``pb``, text when it ends in ``graph`` or ``pbtxt``) whose variables
       are restored from ``"model-checkpoint"``.
    3. ``"model-metagraph"`` -- meta graph imported under the ``"import"``
       scope whose variables are restored from ``"model-checkpoint"``.

    A ``"saved_model"`` key is recognised but loading SavedModels is not
    implemented: a warning is logged and the keys above are still consulted.

    Args:
        config: Mapping holding ``"distribution-root"`` and one of the model
            keys listed above.

    Returns:
        A ``tf.compat.v1.Session`` with the model loaded, or ``None`` when no
        supported model key is present or loading fails (the error and the
        config are logged).
    """
    try:
        if "saved_model" in config:
            full_path = _model_path(config, "saved_model")
            # No loader exists for this format yet; warn instead of silently
            # returning no session.
            logging.warning("Loading saved model %s is not implemented; ignoring it", full_path)
        if "model-frozengraph" in config:
            return _load_frozen_graph(config)
        if "model-graphdef" in config:
            return _load_graphdef_with_checkpoint(config)
        if "model-metagraph" in config:
            return _load_metagraph_with_checkpoint(config)
    except (OSError, google.protobuf.message.DecodeError) as error:
        # Expected failures (missing/unreadable file, corrupt protobuf).
        logging.error("Error while attempting to load: %s", error)
        logging.error("Config: %s", config)
    except Exception as error:
        # Top-level boundary: report any other failure as "no session", but
        # keep the traceback and let KeyboardInterrupt/SystemExit propagate.
        logging.exception("Error while attempting to load: %s, %s", type(error), error)
        logging.error("Config: %s", config)
    return None


def _model_path(config, key):
    """Return ``config[key]`` joined onto ``config["distribution-root"]``."""
    return os.path.join(config["distribution-root"], config[key])


def _read_graph_def(path, binary):
    """Parse the ``GraphDef`` stored at ``path`` as a binary or text protobuf."""
    graph_definition = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(path, "rb" if binary else "r") as graph_file:
        logging.debug("Opened file, deserializing...")
        contents = graph_file.read()
    if binary:
        graph_definition.ParseFromString(contents)
    else:
        text_format.Merge(contents, graph_definition)
    return graph_definition


def _load_frozen_graph(config):
    """Import a binary frozen ``GraphDef`` into a fresh graph and open a session on it."""
    full_path = _model_path(config, "model-frozengraph")
    logging.debug("Loading frozen graph: %s", full_path)
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(_read_graph_def(full_path, binary=True))
    logging.debug("Done, creating session.")
    return _create_tf_session(graph=graph)


def _load_graphdef_with_checkpoint(config):
    """Import a binary or text ``GraphDef`` and restore its variables from a checkpoint."""
    full_path_def = _model_path(config, "model-graphdef")
    full_path_ckpt = _model_path(config, "model-checkpoint")
    logging.debug("Loading graphdef: %s", full_path_def)
    if full_path_def.endswith("pb"):
        binary = True
    elif full_path_def.endswith(("graph", "pbtxt")):
        binary = False
    else:
        raise ValueError("Unknown file type: %s" % full_path_def)
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(_read_graph_def(full_path_def, binary=binary))
        logging.debug("Done, creating session.")
        session = _create_tf_session(graph=graph)
        logging.debug("Restoring checkpoint %s", full_path_ckpt)
        # Saver() without a var_list collects variables from the *default*
        # graph, so it must be built while ``graph`` is the default.
        # NOTE(review): a plain GraphDef carries no collections; confirm the
        # imported graph registers its variables, otherwise Saver() raises
        # "No variables to save".
        saver = tf.compat.v1.train.Saver()
        saver.restore(sess=session, save_path=full_path_ckpt)
    return session


def _load_metagraph_with_checkpoint(config):
    """Import a meta graph under the ``"import"`` scope and restore its checkpoint."""
    full_path_meta = _model_path(config, "model-metagraph")
    full_path_ckpt = _model_path(config, "model-checkpoint")
    logging.debug("Importing metagraph, creating session...")
    # The meta graph is imported into the process-wide default graph; bind the
    # session to that graph explicitly rather than passing None.
    graph = tf.compat.v1.get_default_graph()
    session = _create_tf_session(graph=graph)
    logging.debug("Loading %s", full_path_meta)
    saver = tf.compat.v1.train.import_meta_graph(
        full_path_meta, clear_devices=True, import_scope="import")
    if saver is None:
        # import_meta_graph returns None when the meta graph has no variables.
        raise ValueError("No variables to restore in meta graph: %s" % full_path_meta)
    logging.debug("Restoring checkpoint %s", full_path_ckpt)
    saver.restore(sess=session, save_path=full_path_ckpt)
    return session
def _create_tf_session(graph=None):
    """Create a ``tf.compat.v1.Session``, enabling GPU memory growth first.

    Args:
        graph: ``tf.Graph`` the session should run, or ``None`` to let
            ``tf.compat.v1.Session`` use the current default graph.

    Returns:
        A new ``tf.compat.v1.Session``.
    """
    # Allocate GPU memory on demand instead of reserving it all up front.
    for device in tf.config.experimental.list_physical_devices("GPU"):
        try:
            tf.config.experimental.set_memory_growth(device, True)
        except RuntimeError as error:
            # Memory growth can only be changed before the GPUs are
            # initialized; fall back to the existing setting instead of
            # failing session creation.
            logging.warning("Could not enable memory growth on %s: %s", device, error)
    if graph is None:
        return tf.compat.v1.Session()
    return tf.compat.v1.Session(graph=graph)
|