#!/usr/bin/env python

"""utils:  Various helper functions."""

import logging
import os.path
import sys

import google.protobuf.message
import tensorflow as tf
from google.protobuf import text_format


def create_session_from_config(config):
    """Create a TF1-style session with the model described by *config* loaded.

    The first model entry found selects the loader, checked in this order:

    1. ``"model-frozengraph"``: a frozen binary ``GraphDef``.
    2. ``"model-graphdef"``: a ``GraphDef`` (binary if the path ends in
       ``pb``, text if it ends in ``graph`` or ``pbtxt``) whose variables are
       restored from ``"model-checkpoint"``.
    3. ``"model-metagraph"``: a ``MetaGraphDef`` imported under the
       ``"import"`` scope, with variables restored from ``"model-checkpoint"``.

    All file entries are joined onto ``config["distribution-root"]``. Every
    loader builds its own ``tf.Graph``, so look tensors up via
    ``session.graph`` rather than the process-wide default graph.

    Args:
        config: Mapping holding the keys listed above.

    Returns:
        The new ``tf.compat.v1.Session``, or ``None`` when no model entry is
        present or loading fails. Failures are logged, not raised.
    """
    try:
        if "model-frozengraph" in config:
            return _load_frozen_graph(config)
        if "model-graphdef" in config:
            return _load_graphdef_with_checkpoint(config)
        if "model-metagraph" in config:
            return _load_metagraph_with_checkpoint(config)
        logging.error("No model-frozengraph, model-graphdef or model-metagraph entry in config")
        logging.error("Config: %s", config)
    except (OSError, google.protobuf.message.DecodeError) as load_error:
        logging.error("Error while attempting to load: %s", load_error)
        logging.error("Config: %s", config)
    except Exception as load_error:
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt and SystemExit still propagate. Log the traceback.
        logging.exception("Error while attempting to load: %s, %s", type(load_error), load_error)
        logging.error("Config: %s", config)

    return None


def _parse_graph_def(path, binary=True):
    """Read *path* into a new GraphDef, parsing it as binary or text protobuf."""
    graph_definition = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(path, 'rb' if binary else 'r') as graph_file:
        logging.debug("Opened file, deserializing...")
        if binary:
            graph_definition.ParseFromString(graph_file.read())
        else:
            text_format.Merge(graph_file.read(), graph_definition)
    return graph_definition


def _graph_from_graph_def(graph_definition):
    """Import *graph_definition* into a fresh tf.Graph and return that graph."""
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_definition)
    return graph


def _restore_checkpoint(saver, session, checkpoint_path):
    """Restore *checkpoint_path* into *session*; close the session if that fails."""
    logging.debug('Restoring checkpoint %s', checkpoint_path)
    try:
        saver.restore(sess=session, save_path=checkpoint_path)
    except BaseException:
        # The caller never receives the session on failure, so release it here.
        session.close()
        raise


def _load_frozen_graph(config):
    """Load the frozen binary GraphDef named by config["model-frozengraph"]."""
    full_path = os.path.join(config["distribution-root"], config["model-frozengraph"])
    logging.debug("Loading frozen graph: %s", full_path)
    graph = _graph_from_graph_def(_parse_graph_def(full_path, binary=True))

    logging.debug("Done, creating session.")
    return _create_tf_session(graph=graph)


def _load_graphdef_with_checkpoint(config):
    """Load config["model-graphdef"] and restore config["model-checkpoint"] into it."""
    full_path_def = os.path.join(config["distribution-root"], config["model-graphdef"])
    full_path_ckpt = os.path.join(config["distribution-root"], config["model-checkpoint"])
    logging.debug("Loading graphdef: %s", full_path_def)

    if full_path_def.endswith('pb'):
        binary = True
    elif full_path_def.endswith(('graph', 'pbtxt')):
        binary = False
    else:
        raise ValueError("Unknown file type: %s" % full_path_def)

    graph = _graph_from_graph_def(_parse_graph_def(full_path_def, binary=binary))
    with graph.as_default():
        # Build the Saver inside the model's graph. Built outside, it would
        # collect variables from the process-wide default graph instead.
        # It is built before the session, so a failure here leaks nothing.
        # NOTE(review): import_graph_def does not appear to repopulate variable
        # collections, so Saver() may raise "No variables to save" for a plain
        # GraphDef. Confirm how these checkpoints are meant to be restored.
        saver = tf.compat.v1.train.Saver()

    logging.debug("Done, creating session.")
    session = _create_tf_session(graph=graph)
    _restore_checkpoint(saver, session, full_path_ckpt)
    return session


def _load_metagraph_with_checkpoint(config):
    """Import config["model-metagraph"] and restore config["model-checkpoint"] into it."""
    full_path_meta = os.path.join(config["distribution-root"], config["model-metagraph"])
    full_path_ckpt = os.path.join(config["distribution-root"], config["model-checkpoint"])

    logging.debug('Importing metagraph, creating session...')
    graph = tf.Graph()
    logging.debug('Loading %s', full_path_meta)
    with graph.as_default():
        # Import into a dedicated graph. This is the same pattern as the other
        # loaders, and importing meta graphs is not allowed while executing eagerly.
        saver = tf.compat.v1.train.import_meta_graph(
            full_path_meta, clear_devices=True, import_scope="import")
    if saver is None:
        # import_meta_graph returns None when the metagraph has no variables.
        raise ValueError("Metagraph has no variables to restore: %s" % full_path_meta)

    session = _create_tf_session(graph=graph)
    _restore_checkpoint(saver, session, full_path_ckpt)
    return session


def _create_tf_session(graph=None):
    """Create a ``tf.compat.v1.Session``, enabling GPU memory growth first.

    Args:
        graph: The ``tf.Graph`` the session runs. ``None`` (the default)
            binds the session to the current default graph.

    Returns:
        A new ``tf.compat.v1.Session``.
    """
    # Request on-demand GPU memory allocation. This replaces the TF1
    # ConfigProto option gpu_options.allow_growth.
    for device in tf.config.experimental.list_physical_devices('GPU'):
        try:
            tf.config.experimental.set_memory_growth(device, True)
        except RuntimeError as runtime_error:
            # Raised when the runtime is already initialized. The setting is
            # best-effort, so keep the existing configuration and continue.
            logging.warning("Could not enable memory growth on %s: %s", device, runtime_error)

    # Session(graph=None) already means "use the default graph", so there is
    # no need to touch graph.as_default() here. That call crashed when graph was None.
    return tf.compat.v1.Session(graph=graph)