Eric Tröbs 4 years ago
parent
commit
2cade612c7
3 changed files with 62 additions and 55 deletions
  1. 1 2
      pycs/pipeline/PipelineManager.py
  2. 58 0
      pycs/projects/Project.py
  3. 3 53
      pycs/projects/ProjectManager.py

+ 1 - 2
pycs/pipeline/PipelineManager.py

@@ -3,11 +3,10 @@ from os import path
 from eventlet import tpool
 from eventlet import tpool
 
 
 from pycs.pipeline.Job import Job
 from pycs.pipeline.Job import Job
-from pycs.projects.Project import Project
 
 
 
 
 class PipelineManager:
 class PipelineManager:
-    def __init__(self, project: Project):
+    def __init__(self, project):
         code_path = path.join(project['model']['path'], project['model']['code']['module'])
         code_path = path.join(project['model']['path'], project['model']['code']['module'])
         module_name = code_path.replace('/', '.').replace('\\', '.')
         module_name = code_path.replace('/', '.').replace('\\', '.')
         class_name = project['model']['code']['class']
         class_name = project['model']['code']['class']

+ 58 - 0
pycs/projects/Project.py

@@ -2,13 +2,21 @@ from json import load
 from os import path
 from os import path
 from uuid import uuid1
 from uuid import uuid1
 
 
+from eventlet import spawn_after
+
 from pycs.observable import ObservableDict
 from pycs.observable import ObservableDict
+from pycs.pipeline.PipelineManager import PipelineManager
 from pycs.projects.MediaFile import MediaFile
 from pycs.projects.MediaFile import MediaFile
 from pycs.util.RecursiveDictionary import set_recursive
 from pycs.util.RecursiveDictionary import set_recursive
 
 
 
 
 class Project(ObservableDict):
 class Project(ObservableDict):
+    DEFAULT_PIPELINE_TIMEOUT = 120
+
     def __init__(self, obj: dict, parent):
     def __init__(self, obj: dict, parent):
+        self.pipeline_manager = None
+        self.quit_pipeline_thread = None
+
         # ensure all required object keys are available
         # ensure all required object keys are available
         for key in ['data', 'labels', 'jobs']:
         for key in ['data', 'labels', 'jobs']:
             if key not in obj.keys():
             if key not in obj.keys():
@@ -75,3 +83,53 @@ class Project(ObservableDict):
 
 
         # remove label from list
         # remove label from list
         del self['labels'][identifier]
         del self['labels'][identifier]
+
+    def predict(self, identifiers):
+        # create pipeline
+        pipeline = self.__create_pipeline()
+
+        # run jobs
+        for file_id in identifiers:
+            if file_id in self['data'].keys():
+                pipeline.run(self['data'][file_id])
+
+        # schedule timeout thread
+        self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
+
+    def fit(self):
+        # create pipeline
+        pipeline = self.__create_pipeline()
+
+        # run fit
+        pipeline.fit()
+
+        # schedule timeout thread
+        self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
+
+    def __create_pipeline(self):
+        # abort pipeline termination
+        self.__quit_pipeline_thread()
+
+        # create pipeline if it does not exist already
+        if self.pipeline_manager is None:
+            self.pipeline_manager = PipelineManager(self)
+
+        return self.pipeline_manager
+
+    def __quit_pipeline(self):
+        if self.pipeline_manager is not None:
+            self.pipeline_manager.close()
+            self.pipeline_manager = None
+            self.quit_pipeline_thread = None
+
+    def __create_quit_pipeline_thread(self):
+        # abort pipeline termination
+        self.__quit_pipeline_thread()
+
+        # create new thread
+        self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
+
+    def __quit_pipeline_thread(self):
+        if self.quit_pipeline_thread is not None:
+            self.quit_pipeline_thread.cancel()
+            self.quit_pipeline_thread = None

+ 3 - 53
pycs/projects/ProjectManager.py

@@ -5,21 +5,13 @@ from shutil import rmtree, copytree
 from time import time
 from time import time
 from uuid import uuid1
 from uuid import uuid1
 
 
-from eventlet import spawn_after
-
 from pycs import ApplicationStatus
 from pycs import ApplicationStatus
 from pycs.observable import ObservableDict
 from pycs.observable import ObservableDict
-from pycs.pipeline.PipelineManager import PipelineManager
 from pycs.projects.Project import Project
 from pycs.projects.Project import Project
 
 
 
 
 class ProjectManager(ObservableDict):
 class ProjectManager(ObservableDict):
-    DEFAULT_PIPELINE_TIMEOUT = 120
-
     def __init__(self, app_status: ApplicationStatus):
     def __init__(self, app_status: ApplicationStatus):
-        self.pipeline_manager = None
-        self.quit_pipeline_thread = None
-
         # TODO create projects folder if it does not exist
         # TODO create projects folder if it does not exist
         self.app_status = app_status
         self.app_status = app_status
 
 
@@ -96,27 +88,8 @@ class ProjectManager(ObservableDict):
 
 
         project = self[uuid]
         project = self[uuid]
 
 
-        # abort pipeline termination
-        if self.quit_pipeline_thread is not None:
-            self.quit_pipeline_thread.cancel()
-            self.quit_pipeline_thread = None
-
-        # create pipeline if it does not exist already
-        if self.pipeline_manager is None:
-            self.pipeline_manager = PipelineManager(project)
-
-        # run jobs
-        for file_id in identifiers:
-            if file_id in project['data'].keys():
-                self.pipeline_manager.run(project['data'][file_id])
-
-        # quit timeout thread
-        if self.quit_pipeline_thread is not None:
-            self.quit_pipeline_thread.cancel()
-            self.quit_pipeline_thread = None
-
-        # schedule timeout thread
-        self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
+        # run prediction
+        project.predict(identifiers)
 
 
     def fit(self, uuid):
     def fit(self, uuid):
         # abort if uuid is no valid key
         # abort if uuid is no valid key
@@ -125,28 +98,5 @@ class ProjectManager(ObservableDict):
 
 
         project = self[uuid]
         project = self[uuid]
 
 
-        # abort pipeline termination
-        if self.quit_pipeline_thread is not None:
-            self.quit_pipeline_thread.cancel()
-            self.quit_pipeline_thread = None
-
-        # create pipeline if it does not exist already
-        if self.pipeline_manager is None:
-            self.pipeline_manager = PipelineManager(project)
-
         # run fit
         # run fit
-        self.pipeline_manager.fit()
-
-        # quit timeout thread
-        if self.quit_pipeline_thread is not None:
-            self.quit_pipeline_thread.cancel()
-            self.quit_pipeline_thread = None
-
-        # schedule timeout thread
-        self.quit_pipeline_thread = spawn_after(self.DEFAULT_PIPELINE_TIMEOUT, self.__quit_pipeline)
-
-    def __quit_pipeline(self):
-        if self.pipeline_manager is not None:
-            self.pipeline_manager.close()
-            self.pipeline_manager = None
-            self.quit_pipeline_thread = None
+        project.fit()