utils.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #!/usr/bin/env python
  2. """utils: Various helper functions."""
  3. import logging
  4. import os.path
  5. import sys
  6. import google.protobuf.message
  7. import tensorflow as tf
  8. from google.protobuf import text_format
  9. def create_session_from_config(config):
  10. try:
  11. # (1) Frozen graph
  12. if "model-frozengraph" in config.keys():
  13. full_path = os.path.join(config["distribution-root"], config["model-frozengraph"])
  14. logging.debug("Loading frozen graph: %s" % full_path)
  15. graph = tf.Graph()
  16. with graph.as_default():
  17. graph_definition = tf.compat.v1.GraphDef()
  18. with tf.io.gfile.GFile(full_path, 'rb') as graph_file:
  19. logging.debug("Opened file, deserializing...")
  20. serialized_graph = graph_file.read()
  21. graph_definition.ParseFromString(serialized_graph)
  22. tf.import_graph_def(graph_definition)
  23. logging.debug("Done, creating session.")
  24. return _create_tf_session(graph=graph)
  25. # (2) GraphDef + Checkpoint
  26. elif "model-graphdef" in config.keys():
  27. full_path_def = os.path.join(config["distribution-root"], config["model-graphdef"])
  28. full_path_ckpt = os.path.join(config["distribution-root"], config["model-checkpoint"])
  29. logging.debug("Loading graphdef: %s" % full_path_def)
  30. graph = tf.Graph()
  31. with graph.as_default():
  32. graph_definition = tf.compat.v1.GraphDef()
  33. if full_path_def.endswith('pb'):
  34. with tf.io.gfile.GFile(full_path_def, 'rb') as graph_file:
  35. logging.debug("Opened file, deserializing...")
  36. serialized_graph = graph_file.read()
  37. graph_definition.ParseFromString(serialized_graph)
  38. elif full_path_def.endswith('graph') or full_path_def.endswith('pbtxt'):
  39. with tf.io.gfile.GFile(full_path_def, 'r') as graph_file:
  40. logging.debug("Opened file, deserializing...")
  41. text_format.Merge(graph_file.read(), graph_definition)
  42. else:
  43. raise Exception("Unknown file type: %s" % full_path_def)
  44. tf.import_graph_def(graph_definition)
  45. logging.debug("Done, creating session.")
  46. session = _create_tf_session(graph=graph)
  47. logging.debug('Restoring checkpoint %s' % full_path_ckpt)
  48. saver = tf.compat.v1.train.Saver()
  49. saver.restore(sess=session, save_path=full_path_ckpt)
  50. return session
  51. # (3) Metagraph
  52. elif "model-metagraph" in config.keys():
  53. full_path_meta = os.path.join(config["distribution-root"], config["model-metagraph"])
  54. full_path_ckpt = os.path.join(config["distribution-root"], config["model-checkpoint"])
  55. logging.debug('Importing metagraph, creating session...')
  56. session = _create_tf_session()
  57. logging.debug('Loading %s' % full_path_meta)
  58. saver = tf.compat.v1.train.import_meta_graph(full_path_meta, clear_devices=True, import_scope="import")
  59. logging.debug('Restoring checkpoint %s' % full_path_ckpt)
  60. saver.restore(sess=session, save_path=full_path_ckpt)
  61. return session
  62. except OSError as os_error:
  63. logging.error("Error while attempting to load: %s" % os_error)
  64. logging.error("Config: %s" % config)
  65. except google.protobuf.message.DecodeError as decode_error:
  66. logging.error("Error while attempting to load: %s" % decode_error)
  67. logging.error("Config: %s" % config)
  68. except:
  69. t = sys.exc_info()[0]
  70. v = sys.exc_info()[1]
  71. logging.error("Error while attempting to load: %s, %s" % (t, v))
  72. logging.error("Config: %s" % config)
  73. return None
  74. def _create_tf_session(graph=None):
  75. # tf_config = tf.compat.v1.ConfigProto()
  76. # tf_config.gpu_options.allow_growth = True
  77. gpu_devices = tf.config.experimental.list_physical_devices('GPU')
  78. for device in gpu_devices:
  79. tf.config.experimental.set_memory_growth(device, True)
  80. with graph.as_default():
  81. print(graph)
  82. # TODO convert graph ?
  83. pass
  84. if graph is not None:
  85. session = tf.compat.v1.Session(graph=graph)
  86. else:
  87. session = tf.compat.v1.Session()
  88. return session