LabelProvider.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import glob
  2. import json
  3. import os
  4. import re
  5. from pathlib import Path
  6. from pycs import db
  7. from pycs.database.base import NamedBaseModel
  8. from pycs.interfaces.LabelProvider import LabelProvider as LabelProviderInterface
  9. def _find_files(root: str, config_regex=re.compile(r'^configuration(\d+)?\.json$')):
  10. # list folders in labels/
  11. for folder in Path(root).glob('*'):
  12. # list files
  13. for file_path in folder.iterdir():
  14. # filter configuration files
  15. if not file_path.is_file():
  16. continue
  17. if config_regex.match(file_path.name) is None:
  18. continue
  19. # yield element
  20. yield folder, file_path
  21. class LabelProvider(NamedBaseModel):
  22. """
  23. database class for label providers
  24. """
  25. description = db.Column(db.String)
  26. root_folder = db.Column(db.String, nullable=False)
  27. configuration_file = db.Column(db.String, nullable=False)
  28. # relationships to other models
  29. projects = db.relationship("Project", backref="label_provider", lazy="dynamic")
  30. # contraints
  31. __table_args__ = (
  32. db.UniqueConstraint('root_folder', 'configuration_file'),
  33. )
  34. serialize_only = NamedBaseModel.serialize_only + (
  35. "description",
  36. "root_folder",
  37. "configuration_file",
  38. )
  39. @classmethod
  40. def discover(cls, root: Path):
  41. for folder, conf_path in _find_files(root):
  42. with open(conf_path) as f:
  43. config = json.load(f)
  44. provider, _ = cls.get_or_create(
  45. root_folder=str(folder),
  46. configuration_file=conf_path.name
  47. )
  48. provider.name = config['name']
  49. # returns None if not present
  50. provider.description = config.get('description')
  51. db.session.commit()
  52. @property
  53. def configuration_path(self) -> Path:
  54. return Path(self.root_folder, self.configuration_file)
  55. def load(self) -> LabelProviderInterface:
  56. """
  57. load configuration.json and create an instance from the included code object
  58. :return: LabelProvider instance
  59. """
  60. # load configuration.json
  61. with open(self.configuration_path, 'r') as configuration_file:
  62. configuration = json.load(configuration_file)
  63. # load code
  64. code_path = str(Path(self.root_folder, configuration['code']['module']))
  65. module_name = code_path.replace('/', '.').replace('\\', '.')
  66. class_name = configuration['code']['class']
  67. imported_module = __import__(module_name, fromlist=[class_name])
  68. class_attr = getattr(imported_module, class_name)
  69. # return instance
  70. return class_attr(self.root_folder, configuration)