6
0

Project.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. import os
  2. import shutil
  3. import typing as T
  4. import warnings
  5. from datetime import datetime
  6. from pycs import app
  7. from pycs import db
  8. from pycs.database.base import NamedBaseModel
  9. from pycs.database.Collection import Collection
  10. from pycs.database.File import File
  11. from pycs.database.Label import Label
  12. from pycs.database.util import commit_on_return
class Project(NamedBaseModel):
    """
    DB Model for projects

    Files, labels and collections of a project are exposed as dynamic
    relationships, i.e. they can be filtered further before being loaded.
    """

    # free-text description of the project
    description = db.Column(db.String)

    # creation timestamp; defaults to ``datetime.utcnow`` and is indexed
    created = db.Column(db.DateTime, default=datetime.utcnow,
                        index=True, nullable=False)

    # optional reference to a model row; set to NULL if that row is deleted
    model_id = db.Column(
        db.Integer,
        db.ForeignKey("model.id", ondelete="SET NULL"))

    # optional reference to a label provider row; set to NULL if that row is deleted
    label_provider_id = db.Column(
        db.Integer,
        db.ForeignKey("label_provider.id", ondelete="SET NULL"))

    # folder of the project on disk; unique across all projects
    root_folder = db.Column(db.String, nullable=False, unique=True)

    # NOTE(review): presumably marks that the data lives outside of the
    # root folder -- confirm against the code that creates projects
    external_data = db.Column(db.Boolean, nullable=False)

    # folder that is joined with the file names of the project's files
    data_folder = db.Column(db.String, nullable=False)

    # constraints
    __table_args__ = ()

    # relationships to other models
    # NOTE(review): passive_deletes=True leaves the removal of child rows to
    # the database; presumably the child tables declare ON DELETE rules -- confirm
    files = db.relationship(
        "File",
        backref="project",
        lazy="dynamic",
        passive_deletes=True,
    )

    labels = db.relationship(
        "Label",
        backref="project",
        lazy="dynamic",
        passive_deletes=True,
    )

    collections = db.relationship(
        "Collection",
        backref="project",
        lazy="dynamic",
        passive_deletes=True,
    )

    # fields listed for serialization: those of the base model plus the
    # project's own columns
    serialize_only = NamedBaseModel.serialize_only + (
        "created",
        "description",
        "model_id",
        "label_provider_id",
        "root_folder",
        "external_data",
        "data_folder",
    )
  57. @commit_on_return
  58. def delete(self) -> T.Tuple[dict, dict]:
  59. # pylint: disable=unexpected-keyword-arg
  60. dump = super().delete(commit=False)
  61. model_dump = {}
  62. if self.model_id is not None:
  63. # pylint: disable=unexpected-keyword-arg
  64. model_dump = self.model.delete(commit=False)
  65. if os.path.exists(self.root_folder):
  66. # remove from file system
  67. shutil.rmtree(self.root_folder)
  68. return dump, model_dump
  69. def label(self, identifier: int) -> T.Optional[Label]:
  70. """
  71. get a label using its unique identifier
  72. :param identifier: unique identifier
  73. :return: label
  74. """
  75. return self.labels.filter(Label.id == identifier).one_or_none()
  76. def label_by_reference(self, reference: str) -> T.Optional[Label]:
  77. """
  78. get a label using its reference string
  79. :param reference: reference string
  80. :return: label
  81. """
  82. return self.labels.filter(Label.reference == reference).one_or_none()
  83. def file(self, identifier: int) -> T.Optional[Label]:
  84. """
  85. get a file using its unique identifier
  86. :param identifier: unique identifier
  87. :return: file
  88. """
  89. return self.files.filter(File.id == identifier).one_or_none()
  90. def label_tree(self) -> T.List[Label]:
  91. """
  92. get a list of root labels associated with this project
  93. :return: list of labels
  94. """
  95. warnings.warn("Check performance of this method!")
  96. # pylint: disable=no-member
  97. return self.labels.filter(Label.parent_id.is_(None)).all()
    def label_tree_original(self):
        """
        not implemented; always raises

        The string literal after the raise statement holds a raw-SQL
        implementation (recursive query that collects the labels of the
        project and links children to their parents). It is unreachable and
        is never executed.

        :raises NotImplementedError: always
        """
        raise NotImplementedError
        # pylint: disable=unreachable
        # pylint: disable=pointless-string-statement
        """
        with closing(self.database.con.cursor()) as cursor:
            cursor.execute('''
                WITH RECURSIVE
                    tree AS (
                        SELECT labels.* FROM labels
                            WHERE project = ? AND parent IS NULL
                        UNION ALL
                        SELECT labels.* FROM labels
                            JOIN tree ON labels.parent = tree.id
                    )
                SELECT * FROM tree
            ''', [self.id])
            result = []
            lookup = {}
            for row in cursor.fetchall():
                label = TreeNodeLabel(self.database, row)
                lookup[label.id] = label
                if label.parent_id is None:
                    result.append(label)
                else:
                    lookup[label.parent_id].children.append(label)
            return result
        """
  130. def collection(self, identifier: int) -> T.Optional[Collection]:
  131. """
  132. get a collection using its unique identifier
  133. :param identifier: unique identifier
  134. :return: collection
  135. """
  136. return self.collections.filter(Collection.id == identifier).one_or_none()
  137. def collection_by_reference(self, reference: str) -> T.Optional[Collection]:
  138. """
  139. get a collection using its unique identifier
  140. :param identifier: unique identifier
  141. :return: collection
  142. """
  143. return self.collections.filter(Collection.reference == reference).one_or_none()
  144. @commit_on_return
  145. def create_label(self, name: str,
  146. reference: str = None,
  147. parent: T.Optional[T.Union[int, str, Label]] = None,
  148. hierarchy_level: str = None) -> T.Tuple[T.Optional[Label], bool]:
  149. """
  150. create a label for this project. If there is already a label with the same reference
  151. in the database its name is updated.
  152. :param name: label name
  153. :param reference: label reference
  154. :param parent: parent label. Either a reference string, a Label id or a Label instance
  155. :param hierarchy_level: hierarchy level name
  156. :return: created or edited label, insert
  157. """
  158. label = None
  159. is_new = False
  160. if reference is not None:
  161. label = Label.query.filter_by(project_id=self.id, reference=reference).one_or_none()
  162. if label is None:
  163. label = Label.new(commit=False, project_id=self.id, reference=reference)
  164. is_new = True
  165. label.set_name(name, commit=False)
  166. label.set_parent(parent, commit=False)
  167. label.hierarchy_level = hierarchy_level
  168. return label, is_new
  169. @commit_on_return
  170. def bulk_create_labels(self, labels: T.List[T.Dict], clean_old_labels: bool = True):
  171. """
  172. Inserts a all labels at once.
  173. :raises:
  174. - AssertionError if project_id and reference are not unique
  175. - ValueError if a cycle in the hierarchy is found
  176. """
  177. if clean_old_labels:
  178. self.labels.delete()
  179. for label in labels:
  180. label["project_id"] = self.id
  181. self.__check_labels(labels)
  182. app.logger.info(f"Inserting {len(labels):,d} labels")
  183. db.engine.execute(Label.__table__.insert(), labels)
  184. self.__set_parents(labels)
  185. return labels
  186. def __set_parents(self, labels):
  187. """ after the bul insert, we need to set correct parent_ids """
  188. app.logger.info("Setting parents of the labels")
  189. self.flush()
  190. for label in labels:
  191. if label["parent"] is None:
  192. continue
  193. label_obj = self.label_by_reference(label["reference"])
  194. parent_label_obj = self.label_by_reference(label["parent"])
  195. label_obj.parent_id = parent_label_obj.id
  196. def __check_labels(self, labels):
  197. """ check labels for unique keys and cycles """
  198. unique_keys = dict()
  199. for label in labels:
  200. key = (label["project_id"], label["reference"])
  201. assert key not in unique_keys, \
  202. f"{key} was not unique: ({label=} vs {unique_keys[key]=})!"
  203. unique_keys[key] = label
  204. # pylint: disable=too-many-arguments
  205. @commit_on_return
  206. def create_collection(self,
  207. reference: str,
  208. name: str,
  209. description: str,
  210. position: int,
  211. autoselect: bool) -> T.Tuple[Collection, bool]:
  212. """
  213. create a new collection associated with this project
  214. :param reference: collection reference string
  215. :param name: collection name
  216. :param description: collection description
  217. :param position: position in menus
  218. :param autoselect: automatically select this collection on session load
  219. :return: collection object, insert
  220. """
  221. collection, is_new = Collection.get_or_create(
  222. project_id=self.id, reference=reference)
  223. collection.name = name
  224. collection.description = description
  225. collection.position = position
  226. collection.autoselect = autoselect
  227. return collection, is_new
  228. # pylint: disable=too-many-arguments
  229. @commit_on_return
  230. def add_file(self,
  231. uuid: str,
  232. file_type: str,
  233. name: str,
  234. extension: str,
  235. size: int,
  236. filename: str,
  237. frames: int = None,
  238. fps: float = None) -> T.Tuple[File, bool]:
  239. """
  240. add a file to this project
  241. :param uuid: unique identifier which is used for temporary files
  242. :param file_type: file type (either image or video)
  243. :param name: file name
  244. :param extension: file extension
  245. :param size: file size
  246. :param filename: actual name in filesystem
  247. :param frames: frame count
  248. :param fps: frames per second
  249. :return: file
  250. """
  251. path = os.path.join(self.data_folder, f"{filename}{extension}")
  252. file, is_new = File.get_or_create(
  253. project_id=self.id, path=path)
  254. file.uuid = uuid
  255. file.type = file_type
  256. file.name = name
  257. file.extension = extension
  258. file.size = size
  259. file.frames = frames
  260. file.fps = fps
  261. return file, is_new
  262. def get_files(self, *filters, offset: int = 0, limit: int = -1) -> T.List[File]:
  263. """
  264. get an iterator of files associated with this project
  265. :param offset: file offset
  266. :param limit: file limit
  267. :return: iterator of files
  268. """
  269. return self.files.filter(*filters).order_by(File.id).offset(offset).limit(limit)
  270. def _files_without_results(self):
  271. """
  272. get files without any results
  273. :return: a query object
  274. """
  275. # pylint: disable=no-member
  276. return self.files.filter(~File.results.any())
  277. def count_files_without_results(self) -> int:
  278. """
  279. count files without associated results
  280. :return: count
  281. """
  282. return self._files_without_results().count()
  283. def files_without_results(self) -> T.List[File]:
  284. """
  285. get a list of files without associated results
  286. :return: list of files
  287. """
  288. return self._files_without_results().all()
  289. def _files_without_collection(self, offset: int = 0, limit: int = -1):
  290. """
  291. get files without a collection
  292. :return: a query object
  293. """
  294. # pylint: disable=no-member
  295. return self.get_files(File.collection_id.is_(None), offset=offset, limit=limit)
  296. def files_without_collection(self, offset: int = 0, limit: int = -1) -> T.List[File]:
  297. """
  298. get a list of files without a collection
  299. :return: list of files
  300. """
  301. return self._files_without_collection(offset=offset, limit=limit).all()
  302. def count_files_without_collection(self) -> int:
  303. """
  304. count files associated with this project but without a collection
  305. :return: count
  306. """
  307. return self._files_without_collection().count()