瀏覽代碼

added a new dataset

Dimitri Korsch 6 年之前
父節點
當前提交
b482d8d9d8
共有 5 個文件被更改,包括 129 次插入15 次删除
  1. 2 0
      cvdatasets/annotations/__init__.py
  2. 86 0
      cvdatasets/annotations/tigers.py
  3. 24 15
      scripts/display.py
  4. 9 0
      scripts/info_files/info.yml
  5. 8 0
      scripts/utils/parser.py

+ 2 - 0
cvdatasets/annotations/__init__.py

@@ -5,6 +5,7 @@ from .inat import INAT19_Annotations
 from .flowers import FLOWERS_Annotations
 from .dogs import DOGS_Annotations
 from .hed import HED_Annotations
+from .tigers import TIGERS_Annotations
 
 from .base import BaseAnnotations
 
@@ -19,6 +20,7 @@ class AnnotationType(BaseChoiceType):
 	DOGS = DOGS_Annotations
 	FLOWERS = FLOWERS_Annotations
 	HED = HED_Annotations
+	TIGERS = TIGERS_Annotations
 
 	INAT19 = INAT19_Annotations
 	INAT19_MINI = partial(INAT19_Annotations)

+ 86 - 0
cvdatasets/annotations/tigers.py

@@ -0,0 +1,86 @@
+import numpy as np
+import simplejson as json
+
+from cvdatasets.utils import _MetaInfo
+from .base import BaseAnnotations
+
+from os.path import join, isfile
+
+from sklearn.model_selection import StratifiedShuffleSplit
+
+class TIGERS_Annotations(BaseAnnotations):
+	name="tigers"
+
+	@property
+	def meta(self):
+		info = _MetaInfo(
+			images_folder="train",
+			images_file=join("atrw_anno_reid_train", "reid_list_train.csv"),
+			parts_file=join("atrw_anno_reid_train", "reid_keypoints_train.json"),
+		)
+
+		info.structure = [
+			[info.images_file, "_images"],
+			[info.parts_file, "_part_locs"],
+		]
+		return info
+
+
+	def _load_uuids(self):
+		self.uuids, self.cls_ids = [], []
+		self.uuid_to_idx, self.images = {}, []
+
+		for i, line in enumerate(self._images):
+			cls_id, imname = line.split(",")
+			self.uuids.append(str(i))
+			self.uuid_to_idx[str(i)] = i
+
+			self.images.append(imname)
+			self.cls_ids.append(int(cls_id))
+
+		self.uuids, self.images, self.cls_ids = map(np.array, [self.uuids, self.images, self.cls_ids])
+
+	def _load_labels(self):
+		self.classes, self.labels = np.unique(self.cls_ids, return_inverse=True)
+
+
+	def _load_parts(self):
+		keypoints = []
+		for image in self.images:
+			kpts = self._part_locs[image]
+			kpts = np.array(kpts).reshape(-1, 3)
+			kpt_idxs = np.arange(len(kpts))
+			kpts = np.hstack([kpt_idxs.reshape(-1, 1), kpts])
+
+			keypoints.append(kpts)
+
+		self.part_locs = np.array(keypoints)
+
+		n_parts = self.part_locs.shape[1]
+		self._part_names = [f"{i} part #{i}" for i in range(n_parts)]
+		self._load_part_names()
+
+	def _load_split(self, seed=4211):
+
+		splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=seed)
+
+		(train_IDs, test_IDs), = splitter.split(X=self.uuids, y=self.labels)
+
+		self.train_split = np.zeros_like(self.uuids, dtype=bool)
+
+		self.train_split[train_IDs] = 1
+		self.test_split = np.logical_not(self.train_split)
+
+	def parts(self, *args, **kwargs):
+		if self.has_parts:
+			return super(TIGERS_Annotations, self).parts(*args, **kwargs)
+		return None
+
+	def read_content(self, file, attr):
+		if not file.endswith(".json"):
+			return super(TIGERS_Annotations, self).read_content(file, attr)
+
+		with self._open(file) as f:
+			content = json.load(f)
+
+		setattr(self, attr, content)

+ 24 - 15
scripts/display.py

@@ -48,9 +48,16 @@ def main(args):
 
 	logging.info(f"Loaded {len(data)} {args.subset} images")
 
-	start = max(args.start, 0)
-	n_images = min(args.n_images, len(data) - start)
-	idxs = range(start, max(start, start + n_images))
+	if args.only_class >= 0:
+		logging.info(f"Showing only images from class {args.only_class}")
+		mask = data.labels == args.only_class
+		idxs = np.where(mask)[0]
+	else:
+		start = max(args.start, 0)
+		n_images = min(args.n_images, len(data) - start)
+		end = max(start, start + n_images)
+		logging.info(f"Showing only images {start} - {end}")
+		idxs = range(start, end)
 
 	for i in idxs:
 		im, parts, label = data[i]
@@ -61,28 +68,30 @@ def main(args):
 		axs[0].axis("off")
 		axs[0].set_title("Visible Parts")
 		axs[0].imshow(im)
+
 		if not args.crop_to_bb and not args.no_bboxes:
 			data.plot_bounding_box(i, axs[0])
-		parts.plot(im=im, ax=axs[0], ratio=data.ratio, linewidth=3)
 
 		# axs[1].axis("off")
 		# axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
 		# axs[1].imshow(parts.reveal(im, ratio=data.ratio))
 
-		if data.uniform_parts:
-			crop_names = None
-		else:
-			crop_names = list(data._annot.part_names.values())
+		if not args.no_parts:
+			parts.plot(im=im, ax=axs[0], ratio=data.ratio, linewidth=3)
+			if data.uniform_parts:
+				crop_names = None
+			else:
+				crop_names = list(data._annot.part_names.values())
 
-		part_crops = parts.visible_crops(im, ratio=data.ratio)
-		if args.rnd:
-			parts.invert_selection()
-			action_crops = parts.visible_crops(im, ratio=data.ratio)
+			part_crops = parts.visible_crops(im, ratio=data.ratio)
+			if args.rnd:
+				parts.invert_selection()
+				action_crops = parts.visible_crops(im, ratio=data.ratio)
 
-		plot_crops(part_crops, f"{args.parts}: Selected parts", names=crop_names)
+			plot_crops(part_crops, f"{args.parts}: Selected parts", names=crop_names)
 
-		if args.rnd:
-			plot_crops(action_crops, f"{args.parts}: Actions", names=crop_names)
+			if args.rnd:
+				plot_crops(action_crops, f"{args.parts}: Actions", names=crop_names)
 
 		plt.show()
 		plt.close()

+ 9 - 0
scripts/info_files/info.yml

@@ -99,6 +99,11 @@ DATASETS:
     annotations: "patches224x224"
     n_classes: 2
 
+  TIGERS:         &tigers
+    folder: tigers
+    annotations: "ORIGINAL"
+    n_classes: 107
+
 ############ Existing Part Annotations and Part Features
 ### feature file name composition:
 # ${BASE_DIR}/${DATA_DIR}/${DATASETS:folder}/${PART_TYPES:annotations}/features
@@ -190,6 +195,10 @@ PARTS:
     <<: *hed
     <<: *parts_global
 
+  TIGERS_GLOBAL:
+    <<: *tigers
+    <<: *parts_global
+
   #### With Parts Annotations
 
 

+ 8 - 0
scripts/utils/parser.py

@@ -32,6 +32,10 @@ def parse_args():
 			help="Number of images to display",
 			type=int, default=10),
 
+		Arg("--only_class",
+			help="display only the given class",
+			type=int, default=-1),
+
 
 		Arg("--rnd",
 			help="select random subset of present parts",
@@ -41,6 +45,10 @@ def parse_args():
 			help="Do not display bounding boxes",
 			action="store_true"),
 
+		Arg("--no_parts",
+			help="Do not display parts",
+			action="store_true"),
+
 		Arg("--crop_to_bb",
 			help="Crop image to the bounding box",
 			action="store_true"),