# WebServer.py (16 KB)
import typing as T
from glob import glob
from pathlib import Path

import eventlet
import eventlet.wsgi
import munch
import socketio
from flask import send_from_directory

from pycs.database.LabelProvider import LabelProvider
from pycs.database.Model import Model
from pycs.frontend.endpoints.ListJobs import ListJobs
from pycs.frontend.endpoints.ListLabelProviders import ListLabelProviders
from pycs.frontend.endpoints.ListModels import ListModels
from pycs.frontend.endpoints.ListProjects import ListProjects
from pycs.frontend.endpoints.additional.FolderInformation import FolderInformation
from pycs.frontend.endpoints.additional.Authenticate import Authenticate
from pycs.frontend.endpoints.data.GetCroppedFile import GetCroppedFile
from pycs.frontend.endpoints.data.GetFile import GetFile
from pycs.frontend.endpoints.data.GetPreviousAndNextFile import GetPreviousAndNextFile
from pycs.frontend.endpoints.data.GetResizedFile import GetResizedFile
from pycs.frontend.endpoints.data.RemoveFile import RemoveFile
from pycs.frontend.endpoints.data.UploadFile import UploadFile
from pycs.frontend.endpoints.jobs.RemoveJob import RemoveJob
from pycs.frontend.endpoints.labels.CreateLabel import CreateLabel
from pycs.frontend.endpoints.labels.EditLabelName import EditLabelName
from pycs.frontend.endpoints.labels.EditLabelParent import EditLabelParent
from pycs.frontend.endpoints.labels.ListLabelTree import ListLabelTree
from pycs.frontend.endpoints.labels.ListLabels import ListLabels
from pycs.frontend.endpoints.labels.RemoveLabel import RemoveLabel
from pycs.frontend.endpoints.pipelines.EstimateBoundingBox import EstimateBoundingBox
from pycs.frontend.endpoints.pipelines.FitModel import FitModel
from pycs.frontend.endpoints.pipelines.PredictBoundingBox import PredictBoundingBox
from pycs.frontend.endpoints.pipelines.PredictFile import PredictFile
from pycs.frontend.endpoints.pipelines.PredictModel import PredictModel
from pycs.frontend.endpoints.projects.CreateProject import CreateProject
from pycs.frontend.endpoints.projects.EditProjectDescription import EditProjectDescription
from pycs.frontend.endpoints.projects.EditProjectName import EditProjectName
from pycs.frontend.endpoints.projects.ExecuteExternalStorage import ExecuteExternalStorage
from pycs.frontend.endpoints.projects.ExecuteLabelProvider import ExecuteLabelProvider
from pycs.frontend.endpoints.projects.GetProjectModel import GetProjectModel
from pycs.frontend.endpoints.projects.ListProjectCollections import ListProjectCollections
from pycs.frontend.endpoints.projects.ListProjectFiles import ListProjectFiles
from pycs.frontend.endpoints.projects.RemoveProject import RemoveProject
from pycs.frontend.endpoints.results.ConfirmResult import ConfirmResult
from pycs.frontend.endpoints.results.ConfirmAllResults import ConfirmAllResults
from pycs.frontend.endpoints.results.CopyResults import CopyResults
from pycs.frontend.endpoints.results.CreateResult import CreateResult
from pycs.frontend.endpoints.results.EditResultData import EditResultData
from pycs.frontend.endpoints.results.EditResultLabel import EditResultLabel
from pycs.frontend.endpoints.results.GetProjectResults import GetProjectResults
from pycs.frontend.endpoints.results.GetResults import GetResults
from pycs.frontend.endpoints.results.RemoveResult import RemoveResult
from pycs.frontend.endpoints.results.ResetResults import ResetResults
from pycs.frontend.notifications.NotificationManager import NotificationManager
from pycs.jobs.JobRunner import JobRunner
from pycs.util.PipelineCache import PipelineCache
  56. class WebServer:
  57. """
  58. wrapper class for flask and socket.io which initializes most networking
  59. """
  60. index: Path = Path.cwd() / 'webui' / 'index.html'
  61. def __init__(self, app, htpasswd, settings: munch.Munch, discovery: bool = True):
  62. self.app = app
  63. self.htpasswd = htpasswd
  64. # initialize web server
  65. if self.is_production:
  66. app.logger.info('production build')
  67. # overwrite root path to serve index.html
  68. @self.app.route('/', methods=['GET'])
  69. def index():
  70. # pylint: disable=unused-variable
  71. return send_from_directory(str(self.index.parent), self.index.name)
  72. else:
  73. app.logger.info('development build')
  74. # set access control header to allow requests from Vue.js development server
  75. @self.app.after_request
  76. def after_request(response):
  77. # pylint: disable=unused-variable
  78. response.headers['Access-Control-Allow-Origin'] = 'http://localhost:8080'
  79. response.headers['Access-Control-Allow-Credentials'] = 'true'
  80. response.headers['Access-Control-Allow-Methods'] = 'POST, GET'
  81. response.headers['Access-Control-Allow-Headers'] = 'Authorization'
  82. return response
  83. # create service objects
  84. self.sio = socketio.Server(**self.sio_kwargs(settings.allowedOrigins))
  85. self.wsgi_app = socketio.WSGIApp(self.sio, app, static_files=self.static_files)
  86. self.host = settings.host
  87. self.port = settings.port
  88. # create notification manager
  89. self.jobs = JobRunner()
  90. self.pipelines = PipelineCache(self.jobs, settings.get("pipeline_cache_time"))
  91. self.notifications = NotificationManager(self.sio)
  92. self.jobs.on_create(self.notifications.create_job)
  93. self.jobs.on_start(self.notifications.edit_job)
  94. self.jobs.on_progress(self.notifications.edit_job)
  95. self.jobs.on_finish(self.notifications.edit_job)
  96. self.jobs.on_remove(self.notifications.remove_job)
  97. self.define_routes()
  98. if discovery:
  99. Model.discover("models/")
  100. LabelProvider.discover("labels/")
  101. def sio_kwargs(self, allowed_origins) -> T.Dict[str, T.Union[str, list]]:
  102. """keyword arguments for the socketio.Server depending on the mode"""
  103. kwargs: T.Dict[str, T.Union[str, list]] = dict(async_mode="eventlet")
  104. if self.is_production:
  105. if isinstance(allowed_origins, list) and len(allowed_origins) > 0:
  106. kwargs["cors_allowed_origins"] = allowed_origins
  107. else:
  108. kwargs["cors_allowed_origins"] = "*"
  109. return kwargs
  110. @property
  111. def is_production(self) -> bool:
  112. """property checking, whether the UI is built (production mode)
  113. or served by npm serve (development mode)"""
  114. return self.index.exists()
  115. @property
  116. def static_files(self) -> T.Optional[T.Dict[str, T.Union[str, dict]]]:
  117. """returns a dictionary of static files (production mode)
  118. or None (development mode)"""
  119. if not self.is_production:
  120. return None
  121. # find static files and folders
  122. static_files: T.Dict[str, T.Union[str, dict]] = {}
  123. for file_path in glob('webui/*'):
  124. file_path = file_path.replace('\\', '/')
  125. static_files[file_path[5:]] = file_path
  126. # separately add svg files and set their correct mime type
  127. for svg_path in glob('webui/img/*.svg'):
  128. svg_path = svg_path.replace('\\', '/')
  129. static_files[svg_path[5:]] = {'content_type': 'image/svg+xml', 'filename': svg_path}
  130. return static_files
  131. def define_routes(self):
  132. """ defines app routes """
  133. # authentication
  134. # additional
  135. self.app.add_url_rule(
  136. '/authenticate',
  137. view_func=self.htpasswd.required( Authenticate.as_view('authenticate') )
  138. )
  139. # additional
  140. self.app.add_url_rule(
  141. '/folder',
  142. view_func=self.htpasswd.required( FolderInformation.as_view('folder_information') )
  143. )
  144. # jobs
  145. self.app.add_url_rule(
  146. '/jobs',
  147. view_func=self.htpasswd.required( ListJobs.as_view('list_jobs', self.jobs) )
  148. )
  149. self.app.add_url_rule(
  150. '/jobs/<job_id>/remove',
  151. view_func=self.htpasswd.required( RemoveJob.as_view('remove_job', self.jobs) )
  152. )
  153. # models
  154. self.app.add_url_rule(
  155. '/models',
  156. view_func=self.htpasswd.required( ListModels.as_view('list_models') )
  157. )
  158. self.app.add_url_rule(
  159. '/projects/<int:project_id>/model',
  160. view_func=self.htpasswd.required( GetProjectModel.as_view('get_project_model') )
  161. )
  162. # labels
  163. self.app.add_url_rule(
  164. '/label_providers',
  165. view_func=self.htpasswd.required( ListLabelProviders.as_view('label_providers') )
  166. )
  167. self.app.add_url_rule(
  168. '/projects/<int:project_id>/labels',
  169. view_func=self.htpasswd.required( ListLabels.as_view('list_labels') )
  170. )
  171. self.app.add_url_rule(
  172. '/projects/<int:project_id>/labels/tree',
  173. view_func=self.htpasswd.required( ListLabelTree.as_view('list_label_tree') )
  174. )
  175. self.app.add_url_rule(
  176. '/projects/<int:project_id>/labels',
  177. view_func=self.htpasswd.required( CreateLabel.as_view('create_label',
  178. self.notifications) )
  179. )
  180. self.app.add_url_rule(
  181. '/projects/<int:project_id>/labels/<int:label_id>/remove',
  182. view_func=self.htpasswd.required( RemoveLabel.as_view('remove_label',
  183. self.notifications) )
  184. )
  185. self.app.add_url_rule(
  186. '/projects/<int:project_id>/labels/<int:label_id>/name',
  187. view_func=self.htpasswd.required( EditLabelName.as_view('edit_label_name',
  188. self.notifications) )
  189. )
  190. self.app.add_url_rule(
  191. '/projects/<int:project_id>/labels/<int:label_id>/parent',
  192. view_func=self.htpasswd.required( EditLabelParent.as_view('edit_label_parent',
  193. self.notifications) )
  194. )
  195. # collections
  196. self.app.add_url_rule(
  197. '/projects/<int:project_id>/collections',
  198. view_func=self.htpasswd.required( ListProjectCollections.as_view('list_collections') )
  199. )
  200. self.app.add_url_rule(
  201. '/projects/<int:project_id>/data/<int:collection_id>/<int:start>/<int:length>',
  202. view_func=self.htpasswd.required( ListProjectFiles.as_view('list_collection_files') )
  203. )
  204. # data
  205. self.app.add_url_rule(
  206. '/projects/<int:project_id>/data',
  207. view_func=self.htpasswd.required( UploadFile.as_view('upload_file',
  208. self.notifications) )
  209. )
  210. self.app.add_url_rule(
  211. '/projects/<int:project_id>/data',
  212. view_func=self.htpasswd.required( ListProjectFiles.as_view('list_all_files') )
  213. )
  214. self.app.add_url_rule(
  215. '/projects/<int:project_id>/data/<int:start>/<int:length>',
  216. view_func=self.htpasswd.required( ListProjectFiles.as_view('list_files') )
  217. )
  218. self.app.add_url_rule(
  219. '/data/<int:file_id>/remove',
  220. view_func=self.htpasswd.required( RemoveFile.as_view('remove_file',
  221. self.notifications) )
  222. )
  223. self.app.add_url_rule(
  224. '/data/<int:file_id>',
  225. view_func=GetFile.as_view('get_file')
  226. )
  227. self.app.add_url_rule(
  228. '/data/<int:file_id>/<resolution>',
  229. view_func=GetResizedFile.as_view('get_resized_file')
  230. )
  231. self.app.add_url_rule(
  232. '/data/<int:file_id>/<resolution>/<crop_box>',
  233. view_func=GetCroppedFile.as_view('get_cropped_file')
  234. )
  235. self.app.add_url_rule(
  236. '/data/<int:file_id>/previous_next',
  237. view_func=GetPreviousAndNextFile.as_view('get_previous_and_next_file')
  238. )
  239. # results
  240. self.app.add_url_rule(
  241. '/projects/<int:project_id>/results',
  242. view_func=self.htpasswd.required( GetProjectResults.as_view('get_project_results') )
  243. )
  244. self.app.add_url_rule(
  245. '/data/<int:file_id>/results',
  246. view_func=self.htpasswd.required( GetResults.as_view('get_results') )
  247. )
  248. self.app.add_url_rule(
  249. '/data/<int:file_id>/results',
  250. view_func=self.htpasswd.required( CreateResult.as_view('create_result',
  251. self.notifications) )
  252. )
  253. self.app.add_url_rule(
  254. '/data/<int:file_id>/copy_results',
  255. view_func=self.htpasswd.required( CopyResults.as_view('copy_results',
  256. self.notifications) )
  257. )
  258. self.app.add_url_rule(
  259. '/data/<int:file_id>/confirm_all',
  260. view_func=self.htpasswd.required( ConfirmAllResults.as_view('confirm_all',
  261. self.notifications) )
  262. )
  263. self.app.add_url_rule(
  264. '/data/<int:file_id>/reset',
  265. view_func=self.htpasswd.required( ResetResults.as_view('reset_results',
  266. self.notifications) )
  267. )
  268. self.app.add_url_rule(
  269. '/results/<int:result_id>/remove',
  270. view_func=self.htpasswd.required( RemoveResult.as_view('remove_result',
  271. self.notifications) )
  272. )
  273. self.app.add_url_rule(
  274. '/results/<int:result_id>/confirm',
  275. view_func=self.htpasswd.required( ConfirmResult.as_view('confirm_result',
  276. self.notifications) )
  277. )
  278. self.app.add_url_rule(
  279. '/results/<int:result_id>/label',
  280. view_func=self.htpasswd.required( EditResultLabel.as_view('edit_result_label',
  281. self.notifications) )
  282. )
  283. self.app.add_url_rule(
  284. '/results/<int:result_id>/data',
  285. view_func=self.htpasswd.required( EditResultData.as_view('edit_result_data',
  286. self.notifications) )
  287. )
  288. # projects
  289. self.app.add_url_rule(
  290. '/projects',
  291. view_func=self.htpasswd.required( ListProjects.as_view('list_projects') )
  292. )
  293. self.app.add_url_rule(
  294. '/projects',
  295. view_func=self.htpasswd.required( CreateProject.as_view('create_project',
  296. self.notifications, self.jobs) )
  297. )
  298. self.app.add_url_rule(
  299. '/projects/<int:project_id>/label_provider',
  300. view_func=self.htpasswd.required( ExecuteLabelProvider.as_view('execute_label_provider',
  301. self.notifications, self.jobs) )
  302. )
  303. self.app.add_url_rule(
  304. '/projects/<int:project_id>/external_storage',
  305. view_func=self.htpasswd.required(
  306. ExecuteExternalStorage.as_view('execute_external_storage',
  307. self.notifications, self.jobs) )
  308. )
  309. self.app.add_url_rule(
  310. '/projects/<int:project_id>/remove',
  311. view_func=self.htpasswd.required( RemoveProject.as_view('remove_project',
  312. self.notifications) )
  313. )
  314. self.app.add_url_rule(
  315. '/projects/<int:project_id>/name',
  316. view_func=self.htpasswd.required( EditProjectName.as_view('edit_project_name',
  317. self.notifications) )
  318. )
  319. self.app.add_url_rule(
  320. '/projects/<int:project_id>/description',
  321. view_func=self.htpasswd.required(
  322. EditProjectDescription.as_view('edit_project_description',
  323. self.notifications) )
  324. )
  325. # pipelines
  326. self.app.add_url_rule(
  327. '/projects/<int:project_id>/pipelines/fit',
  328. view_func=self.htpasswd.required( FitModel.as_view('fit_model', self.jobs,
  329. self.pipelines) )
  330. )
  331. self.app.add_url_rule(
  332. '/projects/<int:project_id>/pipelines/predict',
  333. view_func=self.htpasswd.required( PredictModel.as_view('predict_model',
  334. self.notifications, self.jobs, self.pipelines) )
  335. )
  336. self.app.add_url_rule(
  337. '/data/<int:file_id>/predict',
  338. view_func=self.htpasswd.required( PredictFile.as_view('predict_file',
  339. self.notifications, self.jobs, self.pipelines) )
  340. )
  341. self.app.add_url_rule(
  342. '/data/<int:file_id>/<int:bbox_id>/predict_bounding_box',
  343. view_func=self.htpasswd.required( PredictBoundingBox.as_view('predict_bounding_box',
  344. self.notifications, self.jobs, self.pipelines) )
  345. )
  346. self.app.add_url_rule(
  347. '/data/<int:file_id>/estimate',
  348. view_func=EstimateBoundingBox.as_view('estimate_result', self.notifications,
  349. self.jobs)
  350. )
  351. def run(self):
  352. """ start web server """
  353. self.pipelines.start()
  354. eventlet.wsgi.server(eventlet.listen((self.host, self.port)), self.wsgi_app)