6
0

Database.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import sqlite3
  2. from contextlib import closing
  3. from time import time
  4. from typing import Optional, Iterator
  5. from pycs import db
  6. from pycs.database.Collection import Collection
  7. from pycs.database.File import File
  8. from pycs.database.LabelProvider import LabelProvider
  9. from pycs.database.Model import Model
  10. from pycs.database.Project import Project
  11. from pycs.database.Result import Result
  12. from pycs.database.discovery.LabelProviderDiscovery import discover as discover_label_providers
  13. from pycs.database.discovery.ModelDiscovery import discover as discover_models
  14. class Database:
  15. """
  16. opens an sqlite database and allows to access several objects
  17. """
  18. def __init__(self, discovery: bool = True):
  19. """
  20. wrapper for some DB-related runctions. TODO: remove it!
  21. """
  22. if discovery:
  23. discover_models()
  24. discover_label_providers()
  25. def close(self):
  26. """
  27. close database file
  28. """
  29. return
  30. def copy(self):
  31. """
  32. Create a copy of this database object. This can be used to access the database
  33. from another thread. Table initialization and model and label provider discovery is
  34. disabled to speedup this function.
  35. :return: Database
  36. """
  37. return self
  38. def commit(self):
  39. """
  40. commit changes
  41. """
  42. db.session.commit()
  43. def get_object_by_id(self, table_name: str, identifier: int, cls):
  44. """
  45. create an object from cls and a row fetched from table_name and
  46. identified by the identifier
  47. :param table_name: table name
  48. :param identifier: unique identifier
  49. :param cls: class that is used to create the object
  50. :return: object of type cls
  51. """
  52. with closing(self.con.cursor()) as cursor:
  53. cursor.execute(f'SELECT * FROM {table_name} WHERE id = ?', [identifier])
  54. row = cursor.fetchone()
  55. if row is not None:
  56. return cls(self, row)
  57. return None
  58. def get_objects(self, table_name: str, cls):
  59. """
  60. get a list of all available objects in the table
  61. :param table_name: table name
  62. :param cls: class that is used to create the objects
  63. :return: list of object of type cls
  64. """
  65. with closing(self.con.cursor()) as cursor:
  66. cursor.execute(f'SELECT * FROM {table_name}')
  67. for row in cursor:
  68. yield cls(self, row)
  69. def models(self) -> Iterator[Model]:
  70. """
  71. get a list of all available models
  72. :return: iterator of models
  73. """
  74. return self.get_objects("models", Model)
  75. def model(self, identifier: int) -> Optional[Model]:
  76. """
  77. get a model using its unique identifier
  78. :param identifier: unique identifier
  79. :return: model
  80. """
  81. return self.get_object_by_id("models", identifier, Model)
  82. def label_providers(self) -> Iterator[LabelProvider]:
  83. """
  84. get a list of all available label providers
  85. :return: iterator over label providers
  86. """
  87. return self.get_objects("label_providers", LabelProvider)
  88. def label_provider(self, identifier: int) -> Optional[LabelProvider]:
  89. """
  90. get a label provider using its unique identifier
  91. :param identifier: unique identifier
  92. :return: label provider
  93. """
  94. return self.get_object_by_id("label_providers", identifier, LabelProvider)
  95. def projects(self) -> Iterator[Project]:
  96. """
  97. get a list of all available projects
  98. :return: iterator over projects
  99. """
  100. return self.get_objects("projects", Project)
  101. def project(self, identifier: int) -> Optional[Project]:
  102. """
  103. get a project using its unique identifier
  104. :param identifier: unique identifier
  105. :return: project
  106. """
  107. return self.get_object_by_id("projects", identifier, Project)
  108. def create_project(self,
  109. name: str,
  110. description: str,
  111. model: Model,
  112. label_provider: Optional[LabelProvider],
  113. root_folder: str,
  114. external_data: bool,
  115. data_folder: str):
  116. """
  117. insert a project into the database
  118. :param name: project name
  119. :param description: project description
  120. :param model: used model
  121. :param label_provider: used label provider (optional)
  122. :param root_folder: path to project folder
  123. :param external_data: whether an external data directory is used
  124. :param data_folder: path to data folder
  125. :return: created project
  126. """
  127. # prepare some values
  128. created = int(time())
  129. label_provider_id = label_provider.identifier if label_provider is not None else None
  130. # insert statement
  131. with closing(self.con.cursor()) as cursor:
  132. cursor.execute('''
  133. INSERT INTO projects (
  134. name, description, created, model, label_provider, root_folder, external_data,
  135. data_folder
  136. )
  137. VALUES (?, ?, ?, ?, ?, ?, ?, ?)
  138. ''', (name, description, created, model.identifier, label_provider_id, root_folder,
  139. external_data, data_folder))
  140. return self.project(cursor.lastrowid)
  141. def collection(self, identifier: int) -> Optional[Collection]:
  142. """
  143. get a collection using its unique identifier
  144. :param identifier: unique identifier
  145. :return: collection
  146. """
  147. return self.get_object_by_id("collections", identifier, Collection)
  148. def file(self, identifier) -> Optional[File]:
  149. """
  150. get a file using its unique identifier
  151. :param identifier: unique identifier
  152. :return: file
  153. """
  154. return self.get_object_by_id("files", identifier, File)
  155. def result(self, identifier) -> Optional[Result]:
  156. """
  157. get a result using its unique identifier
  158. :param identifier: unique identifier
  159. :return: result
  160. """
  161. return self.get_object_by_id("results", identifier, Result)