utils.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. #!/usr/bin/env python
  2. """utils: Various helper functions."""
  3. import logging
  4. import os.path
  5. import sys
  6. import google.protobuf.message
  7. import tensorflow as tf
  8. from google.protobuf import text_format
  def create_session_from_config(config):
      """Create a TF1-style session for the model described by *config*.

      The loader is chosen by the first key present, checked in this order:

      * ``"model-frozengraph"``: a binary frozen ``GraphDef``.
      * ``"model-graphdef"`` + ``"model-checkpoint"``: a binary (name ending in
        ``pb``) or text (ending in ``graph`` / ``pbtxt``) ``GraphDef`` whose
        variables are then restored from the checkpoint.
      * ``"model-metagraph"`` + ``"model-checkpoint"``: a ``MetaGraphDef``
        imported under the ``"import"`` name scope (devices cleared), then
        restored from the checkpoint.

      A ``"saved_model"`` key is only logged. SavedModel loading is not
      implemented, and lookup falls through to the keys above. Every path is
      joined onto ``config["distribution-root"]``. Each loader builds its own
      ``tf.Graph``, so the returned session is bound to that graph.

      Args:
          config: Mapping from the configuration keys above to relative paths.

      Returns:
          A ``tf.compat.v1.Session`` for the loaded model, or ``None`` when no
          supported model key is present or loading fails (the error and the
          config are logged).
      """
      try:
          if "saved_model" in config:
              full_path = os.path.join(config["distribution-root"], config["saved_model"])
              # TODO: implement SavedModel loading; for now this is only logged.
              logging.debug("Loading saved model %s", full_path)
          if "model-frozengraph" in config:
              return _load_frozen_graph(config)
          if "model-graphdef" in config:
              return _load_graphdef_with_checkpoint(config)
          if "model-metagraph" in config:
              return _load_metagraph_with_checkpoint(config)
      except (OSError, google.protobuf.message.DecodeError) as error:
          logging.error("Error while attempting to load: %s", error)
          logging.error("Config: %s", config)
      except Exception as error:  # pylint: disable=broad-except
          # Best-effort boundary: any other loader failure is logged (with its
          # traceback) and reported to the caller as ``None``.
          logging.exception("Error while attempting to load: %s, %s", type(error), error)
          logging.error("Config: %s", config)
      return None


  def _read_graph_def(path, binary):
      """Deserialize a ``GraphDef`` from a binary or text-format protobuf file."""
      graph_definition = tf.compat.v1.GraphDef()
      with tf.io.gfile.GFile(path, "rb" if binary else "r") as graph_file:
          logging.debug("Opened file, deserializing...")
          if binary:
              graph_definition.ParseFromString(graph_file.read())
          else:
              text_format.Merge(graph_file.read(), graph_definition)
      return graph_definition


  def _restore_or_close(saver, session, checkpoint_path):
      """Restore *checkpoint_path* into *session*; close the session if that fails."""
      try:
          saver.restore(sess=session, save_path=checkpoint_path)
      except BaseException:
          # Do not leak the session (and its device memory) on a failed restore.
          session.close()
          raise


  def _load_frozen_graph(config):
      """Import a binary frozen ``GraphDef`` and return a session bound to it."""
      full_path = os.path.join(config["distribution-root"], config["model-frozengraph"])
      logging.debug("Loading frozen graph: %s", full_path)
      graph = tf.Graph()
      with graph.as_default():
          tf.import_graph_def(_read_graph_def(full_path, binary=True))
      logging.debug("Done, creating session.")
      return _create_tf_session(graph=graph)


  def _load_graphdef_with_checkpoint(config):
      """Import a binary/text ``GraphDef``, then restore its variables from a checkpoint."""
      root = config["distribution-root"]
      full_path_def = os.path.join(root, config["model-graphdef"])
      full_path_ckpt = os.path.join(root, config["model-checkpoint"])
      logging.debug("Loading graphdef: %s", full_path_def)
      if full_path_def.endswith("pb"):
          binary = True
      elif full_path_def.endswith(("graph", "pbtxt")):
          binary = False
      else:
          raise ValueError("Unknown file type: %s" % full_path_def)
      graph = tf.Graph()
      with graph.as_default():
          tf.import_graph_def(_read_graph_def(full_path_def, binary))
          # Build the Saver inside this graph. Outside a graph context TF2 runs
          # eagerly, where tf.compat.v1.train.Saver is not supported.
          # NOTE(review): import_graph_def does not recreate variable
          # collections, so Saver() may find no variables for some models; confirm.
          saver = tf.compat.v1.train.Saver()
      logging.debug("Done, creating session.")
      session = _create_tf_session(graph=graph)
      logging.debug("Restoring checkpoint %s", full_path_ckpt)
      _restore_or_close(saver, session, full_path_ckpt)
      return session


  def _load_metagraph_with_checkpoint(config):
      """Import a ``MetaGraphDef`` into a fresh graph and restore it from a checkpoint."""
      root = config["distribution-root"]
      full_path_meta = os.path.join(root, config["model-metagraph"])
      full_path_ckpt = os.path.join(root, config["model-checkpoint"])
      logging.debug("Importing metagraph, creating session...")
      # Use a dedicated graph, like the other loaders: import_meta_graph cannot
      # run under eager execution, and ops would otherwise pile up in the
      # process-wide default graph across calls.
      graph = tf.Graph()
      with graph.as_default():
          logging.debug("Loading %s", full_path_meta)
          saver = tf.compat.v1.train.import_meta_graph(
              full_path_meta, clear_devices=True, import_scope="import")
      if saver is None:
          # import_meta_graph returns None when the metagraph has no variables.
          raise ValueError("No variables to restore in %s" % full_path_meta)
      session = _create_tf_session(graph=graph)
      logging.debug("Restoring checkpoint %s", full_path_ckpt)
      _restore_or_close(saver, session, full_path_ckpt)
      return session
  def _create_tf_session(graph=None):
      """Create a ``tf.compat.v1.Session``, enabling GPU memory growth first.

      Args:
          graph: ``tf.Graph`` to launch in the session. ``None`` launches the
              current default graph, as ``tf.compat.v1.Session`` does.

      Returns:
          A new ``tf.compat.v1.Session``. The caller is responsible for closing it.
      """
      # Let each visible GPU allocate memory on demand instead of reserving it
      # all up front (TF2 counterpart of ConfigProto.gpu_options.allow_growth).
      for device in tf.config.experimental.list_physical_devices("GPU"):
          try:
              tf.config.experimental.set_memory_growth(device, True)
          except RuntimeError as error:
              # set_memory_growth raises RuntimeError once the runtime is
              # already initialized; the setting can no longer change, so keep
              # going instead of failing session creation.
              logging.debug("Could not set memory growth on %s: %s", device, error)
      logging.debug("Creating session for graph %s", graph)
      # TODO: convert graph?
      return tf.compat.v1.Session(graph=graph)