|
@@ -84,16 +84,16 @@ class UniformPartMixin(BaseMixin):
|
|
|
|
|
|
class RandomBlackOutMixin(BaseMixin):
|
|
|
|
|
|
- def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
|
|
|
+ def __init__(self, seed=None, rnd_select=False, blackout_parts=None, *args, **kwargs):
|
|
|
super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
|
|
|
self.rnd = np.random.RandomState(seed)
|
|
|
self.rnd_select = rnd_select
|
|
|
- self.n_parts = n_parts
|
|
|
+ self.blackout_parts = blackout_parts
|
|
|
|
|
|
def get_example(self, i):
|
|
|
im_obj = super(RandomBlackOutMixin, self).get_example(i)
|
|
|
if self.rnd_select:
|
|
|
- return im_obj.select_random_parts(rnd=self.rnd, n_parts=self.n_parts)
|
|
|
+ return im_obj.select_random_parts(rnd=self.rnd, n_parts=self.blackout_parts)
|
|
|
return im_obj
|
|
|
|
|
|
|