Przeglądaj źródła

added method for profiling of image processing

Dimitri Korsch 3 lat temu
rodzic
commit
296db61abe

+ 10 - 2
cvfinetune/finetuner/base.py

@@ -1,5 +1,7 @@
 import chainer
 import logging
+import numpy as np
+import typing as T
 
 from cvfinetune.finetuner import mixins
 
@@ -15,9 +17,11 @@ class DefaultFinetuner(
 
 	"""
 
-	def __init__(self, *args, config: dict = {},  gpu = [-1], **kwargs):
-		super().__init__(*args, **kwargs)
+	def __init__(self, *args, config: dict = {}, gpu = [-1],
+		         seed: T.Optional[int] = None, **kwargs):
 
+		super().__init__(*args, **kwargs)
+		self.init_random_state(seed=seed)
 		self.init_experiment(config=config)
 
 		self.gpu_config(gpu)
@@ -39,6 +43,10 @@ class DefaultFinetuner(
 		msg = msg or f"<{type(self).__name__}> {attr_name} attribute was not initialized!"
 		assert hasattr(self, attr_name), msg
 
+	def init_random_state(self, seed: T.Optional[int] = None):
+		logging.info(f"Using seed: {seed}")
+		self.rnd = np.random.RandomState(seed)
+
 	def init_device(self):
 		self.device = chainer.get_device(self.device_id)
 		self.device.use()

+ 14 - 0
cvfinetune/finetuner/mixins/dataset.py

@@ -1,6 +1,7 @@
 import abc
 import logging
 import typing as T
+import chainer
 
 from cvdatasets import AnnotationType
 from cvdatasets import AnnotationArgs
@@ -110,3 +111,16 @@ class _DatasetMixin(BaseMixin):
         return ds
 
 
+    def profile_images(self, train: bool = True):
+        ds = self.train_data if train else self.val_data
+        subset_name = "training" if train else "validation"
+
+        if isinstance(ds, chainer.datasets.SubDataset):
+            ds = ds._dataset
+
+
+        logging.info(f"Profiling {subset_name} image processing: ")
+        with ds.enable_img_profiler():
+            ds[self.rnd.randint(len(ds))]
+
+