Parcourir la source

updated tests

Dimitri Korsch il y a 5 ans
Parent
commit
c1b08d070f
7 fichiers modifiés avec 78 ajouts et 5 suppressions
  1. 1 1
      Makefile
  2. 1 1
      scripts/config.sh
  3. 7 0
      tests/__init__.py
  4. 3 3
      tests/configs.py
  5. 11 0
      tests/main.py
  6. 0 0
      tests/test_info.yml
  7. 55 0
      tests/test_parts.py

+ 1 - 1
Makefile

@@ -15,4 +15,4 @@ get_version:
 	@python setup.py --version
 
 run_tests:
-	@bash scripts/tests.sh .
+	@cd scripts && bash tests.sh

+ 1 - 1
scripts/config.sh

@@ -1,4 +1,4 @@
 source /home/korsch/.anaconda3/etc/profile.d/conda.sh
-conda activate chainer4
+conda activate ${ENV:-chainer5}
 
 PYTHON="python" #-m cProfile -o profile"

+ 7 - 0
tests/__init__.py

@@ -0,0 +1,7 @@
+from .test_annotations import *
+from .test_parts import *
+
+__all__ = [
+	"AnnotationTest",
+	"PartCropTest",
+]

+ 3 - 3
tests/configs.py

@@ -1,9 +1,9 @@
 import os
 import abc
 
-from os.path import *
+from pathlib import Path
 
 class config(abc.ABC):
-	BASE_DIR = abspath(os.environ.get("BASE_DIR", "."))
+	BASE_DIR = Path(os.environ.get("BASE_DIR", Path(__file__).parent))
 
-	INFO_FILE = join(BASE_DIR, "info_files", "test_info.yml")
+	INFO_FILE = str(BASE_DIR / "test_info.yml")

+ 11 - 0
tests/main.py

@@ -0,0 +1,11 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+import sys
+import unittest
+from pathlib import Path
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from tests import *
+
+unittest.main()

+ 0 - 0
scripts/info_files/test_info.yml → tests/test_info.yml


+ 55 - 0
tests/test_parts.py

@@ -0,0 +1,55 @@
+import unittest
+import numpy as np
+
+from cvdatasets.dataset.part.base import BasePart
+
+class PartCropTest(unittest.TestCase):
+
+
+	def setUp(self):
+		self.im = np.random.randn(300, 300, 3)
+
+	def tearDown(self):
+		pass
+
+	def _check_crop(self, cropped_im, _should):
+
+		self.assertIsNotNone(cropped_im,
+			"method crop should return something!")
+
+		self.assertIsInstance(cropped_im, type(self.im),
+			"result should have the same type as the input image")
+
+		crop_h, crop_w, _ = cropped_im.shape
+		h, w, _ = _should.shape
+		self.assertEqual(crop_h, h, "incorrect crop height")
+		self.assertEqual(crop_w, w, "incorrect crop width")
+
+		self.assertTrue((cropped_im == _should).all(),
+			"crop was incorret")
+
+
+
+	def test_bbox_part_crop(self):
+		_id, x, y, w, h = annotation = (0, 20, 20, 100, 100)
+
+		bbox = BasePart.new(self.im, annotation)
+
+		cropped_im = bbox.crop(self.im)
+
+		_should = self.im[y:y+h, x:x+w]
+		self._check_crop(cropped_im, _should)
+
+	def test_location_part_crop(self):
+		_id, center_x, center_y, _vis = annotation = (0, 50, 50, 1)
+
+		bbox = BasePart.new(self.im, annotation)
+
+		h, w, c = self.im.shape
+		for ratio in np.linspace(0.1, 0.3, num=9):
+			_h, _w = int(h * ratio), int(w * ratio)
+			cropped_im = bbox.crop(self.im, ratio=ratio)
+			x, y = center_x - _h // 2, center_y - _w // 2
+			_should = self.im[y : y + _h, x : x + _w]
+
+			self._check_crop(cropped_im, _should)