Eric Tröbs 4 年之前
父節點
當前提交
42aa4c75e8
共有 4 個文件被更改,包括 52 次插入5 次删除
  1. 1 1
      app.py
  2. 7 1
      pycs/frontend/WebServer.py
  3. 37 3
      pycs/pipeline/PipelineManager.py
  4. 7 0
      webui/src/components/window/side-navigation-bar.vue

+ 1 - 1
app.py

@@ -23,4 +23,4 @@ if __name__ == '__main__':
 
     # start web server
     print('- start web server')
-    web_server = WebServer(app_status)
+    web_server = WebServer(app_status, pipeline_manager)

+ 7 - 1
pycs/frontend/WebServer.py

@@ -10,13 +10,14 @@ from flask import Flask, make_response, send_from_directory, request
 from werkzeug import formparser
 
 from pycs.ApplicationStatus import ApplicationStatus
+from pycs.pipeline.PipelineManager import PipelineManager
 from pycs.util.GenericWrapper import GenericWrapper
 from pycs.util.ProgressFileWriter import ProgressFileWriter
 from pycs.util.RecursiveDictionary import set_recursive
 
 
 class WebServer:
-    def __init__(self, app_status: ApplicationStatus):
+    def __init__(self, app_status: ApplicationStatus, pipeline_manager: PipelineManager):
         # initialize web server
         if exists('webui/index.html'):
             print('production build')
@@ -203,6 +204,11 @@ class WebServer:
             file_name = target_object['id'] + target_object['extension']
             return send_from_directory(file_directory, file_name)
 
+        @self.__flask.route('/projects/<project_identifier>/data/<file_identifier>', methods=['POST'])
+        def run_prediction(project_identifier, file_identifier):
+            pipeline_manager.run(project_identifier, file_identifier)
+            return response()
+
         # finally start web server
         host = app_status['settings']['frontend']['host']
         port = app_status['settings']['frontend']['port']

+ 37 - 3
pycs/pipeline/PipelineManager.py

@@ -11,7 +11,41 @@ from pycs.pipeline.tf1.pipeline import Pipeline as TF1Pipeline
 class PipelineManager:
     def __init__(self, app_status: ApplicationStatus):
         self.app_status = app_status
-        app_status['projects'].subscribe(self.__update)
+        # app_status['projects'].subscribe(self.__update)
+
+    def run(self, project_identifier, file_identifier):
+        # find project
+        opened_projects = list(filter(lambda x: x['id'] == project_identifier, self.app_status['projects']))
+        if len(opened_projects) == 0:
+            return
+
+        current_project = opened_projects[0]
+
+        # find data object
+        data_objects = list(filter(lambda x: x['id'] == file_identifier, current_project['data']))
+        if len(data_objects) == 0:
+            return
+
+        target_objects = list(filter(lambda o: 'predictionResults' not in o.keys(), current_project['data']))
+        if len(target_objects) == 0:
+            return
+
+        # load pipeline
+        pipeline = tpool.execute(self.__load_pipeline, current_project['pipeline']['model-distribution'])
+
+        for target_object in target_objects:
+            print('>>>', target_object)
+
+            # create job list
+            # TODO update job progress
+            job = Job('detect-faces', current_project['id'], target_object)
+            result = tpool.execute(lambda p, j: p.execute(j), pipeline, job)
+            target_object['predictionResults'] = result.predictions
+
+            print('<<<', target_object)
+
+        # close pipeline
+        pipeline.close()
 
     def __update(self, data):
         # get current project path
@@ -40,8 +74,8 @@ class PipelineManager:
         # close pipeline
         pipeline.close()
 
-    def __load_pipeline(self, identifier):
-        model_distribution = self.app_status['models'][identifier]
+    def __load_pipeline(self, pipeline_identifier):
+        model_distribution = self.app_status['models'][pipeline_identifier]
 
         if model_distribution['mode'] == 'tf1':
             model_root = path.join(getcwd(), 'models', model_distribution['name'])

+ 7 - 0
webui/src/components/window/side-navigation-bar.vue

@@ -36,6 +36,13 @@
         <span>About PyCS</span>
       </div>
 
+      <div class="item"
+           :class="{inactive: !currentProject}"
+           @click="socket.post('/projects/' + currentProject.id + '/data/' + currentProject.data[0].id, {})">
+        <img src="@/assets/icons/info.svg">
+        <span>Run Prediction</span>
+      </div>
+
       <div v-if="window.wide"
            class="item"
            :class="{rotate: !collapsed}"