Model.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import json
  2. import typing as T
  3. from pathlib import Path
  4. from pycs import db
  5. from pycs.database.base import NamedBaseModel
  6. from pycs.database.util import commit_on_return
  7. class Model(NamedBaseModel):
  8. """
  9. DB model for ML models
  10. """
  11. description = db.Column(db.String)
  12. root_folder = db.Column(db.String, nullable=False, unique=True)
  13. supports_encoded = db.Column(db.String, nullable=False)
  14. # relationships to other models
  15. projects = db.relationship("Project", backref="model", lazy="dynamic")
  16. serialize_only: tuple = NamedBaseModel.serialize_only + (
  17. "description",
  18. "root_folder",
  19. )
  20. def serialize(self):
  21. result = super().serialize()
  22. result["supports"] = self.supports
  23. return result
  24. @classmethod
  25. def discover(cls, root: T.Union[Path, str], config_name: str = "configuration.json"):
  26. """
  27. searches for models under the given path
  28. and stores them in the database
  29. """
  30. for folder in Path(root).glob("*"):
  31. with open(folder / config_name, encoding='utf8') as config_file:
  32. config = json.load(config_file)
  33. # extract data
  34. name = config['name']
  35. description = config.get('description', None)
  36. supports = config['supports']
  37. model, _ = cls.get_or_create(root_folder=str(folder))
  38. model.name = name
  39. model.description = description
  40. model.supports = supports
  41. model.flush()
  42. db.session.commit()
  43. @property
  44. def supports(self) -> dict:
  45. """ getter for the 'supports' attribute """
  46. return json.loads(self.supports_encoded)
  47. @supports.setter
  48. def supports(self, value):
  49. """
  50. setter for the 'supports' attribute.
  51. The attribute is encoded property before assigned to the object.
  52. """
  53. if isinstance(value, str):
  54. self.supports_encoded = value
  55. elif isinstance(value, (dict, list)):
  56. self.supports_encoded = json.dumps(value)
  57. else:
  58. raise ValueError(f"Not supported type: {type(value)}")
  59. @commit_on_return
  60. def copy_to(self, name: str, root_folder: str):
  61. """ copies current model to another folder and updates the name """
  62. model, is_new = Model.get_or_create(root_folder=root_folder)
  63. model.name = name
  64. model.description = self.description
  65. model.supports_encoded = self.supports_encoded
  66. return model, is_new