|
@@ -1,5 +1,7 @@
|
|
|
import chainer
|
|
|
import logging
|
|
|
+import numpy as np
|
|
|
+import typing as T
|
|
|
|
|
|
from cvfinetune.finetuner import mixins
|
|
|
|
|
@@ -15,9 +17,11 @@ class DefaultFinetuner(
|
|
|
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, *args, config: dict = {}, gpu = [-1], **kwargs):
|
|
|
- super().__init__(*args, **kwargs)
|
|
|
+ def __init__(self, *args, config: dict = {}, gpu = [-1],
|
|
|
+ seed: T.Optional[int] = None, **kwargs):
|
|
|
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
+ self.init_random_state(seed=seed)
|
|
|
self.init_experiment(config=config)
|
|
|
|
|
|
self.gpu_config(gpu)
|
|
@@ -39,6 +43,10 @@ class DefaultFinetuner(
|
|
|
msg = msg or f"<{type(self).__name__}> {attr_name} attribute was not initialized!"
|
|
|
assert hasattr(self, attr_name), msg
|
|
|
|
|
|
+ def init_random_state(self, seed: T.Optional[int] = None):
|
|
|
+ logging.info(f"Using seed: {seed}")
|
|
|
+ self.rnd = np.random.RandomState(seed)
|
|
|
+
|
|
|
def init_device(self):
|
|
|
self.device = chainer.get_device(self.device_id)
|
|
|
self.device.use()
|