12345678910111213141516171819202122232425 |
- #!/usr/bin/env python
- """tfmodel: Base class for tensorflow based models."""
- from pycs.utils import Errorable
- from .utils import create_session_from_config
class TFModelException(Exception):
    """Exception type for errors raised by TensorFlow-based models.

    Args:
        message: Human-readable description of the failure; becomes the
            exception's sole argument (``str(exc) == message``).
    """

    def __init__(self, message):
        # Delegate to the base initializer so ``args``/``str()`` behave
        # exactly like a plain ``Exception(message)``.
        super(TFModelException, self).__init__(message)


class TFModel(Errorable):
    """Base class for TensorFlow-based models.

    Builds a session via ``create_session_from_config`` and exposes the
    session's graph as ``tf_graph``. A failed session creation is reported
    through the ``Errorable`` ``last_error`` attribute instead of raising.
    """

    def __init__(self, config):
        """Create the model's session from *config*.

        Args:
            config: Passed unchanged to ``create_session_from_config``;
                its expected contents are defined by that helper.

        If the helper returns ``None``, ``last_error`` is set to
        ``'Session creation failed.'`` and both ``tf_session`` and
        ``tf_graph`` are left as ``None``.
        """
        Errorable.__init__(self)
        self.tf_session = create_session_from_config(config)
        if self.tf_session is None:
            self.last_error = 'Session creation failed.'
            # Stop here: reading ``.graph`` off a None session would raise
            # AttributeError and hide the error recorded just above.
            self.tf_graph = None
            return
        # presumably a tf.Session-like object exposing ``.graph`` — TODO confirm
        self.tf_graph = self.tf_session.graph

    def close(self):
        """Close the underlying session; a no-op if creation failed."""
        if self.tf_session is not None:
            self.tf_session.close()
|