|
@@ -1,9 +1,31 @@
|
|
|
+import glob
|
|
|
import json
|
|
|
+import os
|
|
|
+import re
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
from pycs import db
|
|
|
from pycs.database.base import NamedBaseModel
|
|
|
+from pycs.interfaces.LabelProvider import LabelProvider as LabelProviderInterface
|
|
|
+
|
|
|
+
|
|
|
+def __find_files(root: str, config_regex=re.compile(r'^configuration(\d+)?\.json$')):
|
|
|
+ # list folders in labels/
|
|
|
+ for folder in Path(root).glob('*'):
|
|
|
+ # list files
|
|
|
+ for file_path in folder.iterdir():
|
|
|
+
|
|
|
+ # filter configuration files
|
|
|
+ if not file_path.isfile():
|
|
|
+ continue
|
|
|
+
|
|
|
+ if config_regex.match(file_path.name) is None:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # yield element
|
|
|
+ yield folder, file_path
|
|
|
+
|
|
|
|
|
|
class LabelProvider(NamedBaseModel):
|
|
|
"""
|
|
@@ -11,25 +33,63 @@ class LabelProvider(NamedBaseModel):
|
|
|
"""
|
|
|
|
|
|
description = db.Column(db.String)
|
|
|
- root_folder = db.Column(db.String, nullable=False, unique=True)
|
|
|
+ root_folder = db.Column(db.String, nullable=False)
|
|
|
+ configuration_file = db.Column(db.String, nullable=False)
|
|
|
|
|
|
# relationships to other models
|
|
|
projects = db.relationship("Project", backref="label_provider", lazy="dynamic")
|
|
|
- serialize_rules = ('-projects',)
|
|
|
+
|
|
|
+ # contraints
|
|
|
+ __table_args__ = (
|
|
|
+ db.UniqueConstraint('root_folder', 'configuration_file'),
|
|
|
+ )
|
|
|
+
|
|
|
+ serialize_only = NamedBaseModel.serialize_only + (
|
|
|
+ "description",
|
|
|
+ "root_folder",
|
|
|
+ "configuration_file",
|
|
|
+ )
|
|
|
|
|
|
@classmethod
|
|
|
- def discover(cls, root: Path, config_name: str = "configuration.json"):
|
|
|
+ def discover(cls, root: Path):
|
|
|
|
|
|
- for folder in Path(root).glob("*"):
|
|
|
- with open(folder / config_name) as f:
|
|
|
+ for folder, conf_path in __find_files(root):
|
|
|
+ with open(conf_path) as f:
|
|
|
config = json.load(f)
|
|
|
|
|
|
- # extract data
|
|
|
- name = config['name']
|
|
|
- description = config.get('description', None)
|
|
|
+ provider, _ = cls.get_or_create(
|
|
|
+ root_folder=str(folder),
|
|
|
+ configuration_file=conf_path.name
|
|
|
+ )
|
|
|
|
|
|
- provider, _ = cls.get_or_create(root_folder=str(folder))
|
|
|
- provider.name = name
|
|
|
- provider.description = description
|
|
|
+ provider.name = config['name']
|
|
|
+
|
|
|
+ # returns None if not present
|
|
|
+ provider.description = config.get('description')
|
|
|
|
|
|
db.session.commit()
|
|
|
+
|
|
|
+ @property
|
|
|
+ def configuration_path(self) -> Path:
|
|
|
+ return Path(self.root_folder, self.configuration_file)
|
|
|
+
|
|
|
+ def load(self) -> LabelProviderInterface:
|
|
|
+ """
|
|
|
+ load configuration.json and create an instance from the included code object
|
|
|
+
|
|
|
+ :return: LabelProvider instance
|
|
|
+ """
|
|
|
+ # load configuration.json
|
|
|
+ with open(self.configuration_path, 'r') as configuration_file:
|
|
|
+ configuration = json.load(configuration_file)
|
|
|
+
|
|
|
+ # load code
|
|
|
+ code_path = Path(self.root_folder, configuration['code']['module']).resolve()
|
|
|
+ module_name = code_path.replace('/', '.').replace('\\', '.')
|
|
|
+ class_name = configuration['code']['class']
|
|
|
+
|
|
|
+ imported_module = __import__(module_name, fromlist=[class_name])
|
|
|
+ class_attr = getattr(imported_module, class_name)
|
|
|
+
|
|
|
+ # return instance
|
|
|
+ return class_attr(self.root_folder, configuration)
|