|
@@ -1,10 +1,13 @@
|
|
|
import abc
|
|
|
-import logging
|
|
|
import copy
|
|
|
+import logging
|
|
|
import numpy as np
|
|
|
|
|
|
from collections import OrderedDict
|
|
|
from collections import defaultdict
|
|
|
+from typing import Union
|
|
|
+
|
|
|
+from cvdatasets.annotation.files import AnnotationFiles
|
|
|
from cvdatasets.utils.decorators import only_with_info
|
|
|
|
|
|
class PartsMixin(abc.ABC):
|
|
@@ -25,18 +28,19 @@ class PartsMixin(abc.ABC):
|
|
|
|
|
|
super(PartsMixin, self).__init__(*args, **kwargs)
|
|
|
|
|
|
- def read_annotation_files(self):
|
|
|
-
|
|
|
+ def read_annotation_files(self) -> AnnotationFiles:
|
|
|
files = super(PartsMixin, self).read_annotation_files()
|
|
|
+ logging.debug("Adding part annotation files")
|
|
|
files.load_files(
|
|
|
part_locs=("parts/part_locs.txt", True),
|
|
|
part_names=("parts/parts.txt", True),
|
|
|
)
|
|
|
+
|
|
|
return files
|
|
|
|
|
|
@property
|
|
|
@only_with_info
|
|
|
- def dataset_info(self):
|
|
|
+ def dataset_info(self) -> dict:
|
|
|
ds_info = super(PartsMixin, self).dataset_info
|
|
|
if self.part_type is not None:
|
|
|
parts_key = f"{self.dataset_key}_{self.part_type}"
|
|
@@ -50,7 +54,7 @@ class PartsMixin(abc.ABC):
|
|
|
|
|
|
return ds_info
|
|
|
|
|
|
- def check_dataset_kwargs(self, subset, **kwargs):
|
|
|
+ def check_dataset_kwargs(self, subset, **kwargs) -> dict:
|
|
|
if self.dataset_info is None:
|
|
|
return kwargs
|
|
|
|
|
@@ -64,20 +68,20 @@ class PartsMixin(abc.ABC):
|
|
|
return super(PartsMixin, self).check_dataset_kwargs(subset, **new_kwargs)
|
|
|
|
|
|
@property
|
|
|
- def has_parts(self):
|
|
|
+ def has_parts(self) -> bool:
|
|
|
return self.files.part_locs is not None
|
|
|
|
|
|
@property
|
|
|
- def has_part_names(self):
|
|
|
+ def has_part_names(self) -> bool:
|
|
|
return self.files.part_names is not None
|
|
|
|
|
|
- def parse_annotations(self):
|
|
|
+ def parse_annotations(self) -> None:
|
|
|
super(PartsMixin, self).parse_annotations()
|
|
|
|
|
|
if self.has_parts:
|
|
|
self._parse_parts()
|
|
|
|
|
|
- def _parse_parts(self):
|
|
|
+ def _parse_parts(self) -> None:
|
|
|
logging.debug("Parsing part annotations")
|
|
|
assert self.has_parts, \
|
|
|
"Part locations were not loaded!"
|
|
@@ -96,7 +100,7 @@ class PartsMixin(abc.ABC):
|
|
|
if self.has_part_names:
|
|
|
self._parse_part_names()
|
|
|
|
|
|
- def _parse_part_names(self):
|
|
|
+ def _parse_part_names(self) -> None:
|
|
|
self.part_names.clear()
|
|
|
self.part_name_list.clear()
|
|
|
|
|
@@ -105,7 +109,7 @@ class PartsMixin(abc.ABC):
|
|
|
self.part_names[int(part_idx)] = name
|
|
|
self.part_name_list.append(name)
|
|
|
|
|
|
- def parts(self, uuid):
|
|
|
+ def parts(self, uuid) -> Union[np.ndarray, None]:
|
|
|
if self.has_parts:
|
|
|
return self.part_locs[self.uuid_to_idx[uuid]].copy()
|
|
|
|