Эх сурвалжийг харах

added python-only version of the classifier evaluation

Dimitri Korsch 4 жил өмнө
parent
commit
9bf5b8cbce

+ 94 - 0
main.py

@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+import os
+os.environ["OMP_NUM_THREADS"] = "2"
+
+import chainer
+import cupy
+import numpy as np
+
+from chainer.backends import cuda
+from chainer.dataset import concat_examples
+from chainer.dataset.iterator import Iterator
+from cvargparse import Arg
+from cvargparse import BaseParser
+from cvmodelz.models import ModelFactory
+from tqdm.auto import tqdm
+
+from moth_classifier import dataset as ds
+from moth_classifier import parser
+
+def evaluate(model: chainer.Chain, iterator: Iterator, device_id: int = -1) -> float:
+	if device_id >= 0:
+		device = cuda.get_device_from_id(device_id)
+		device.use()
+	else:
+		# for CPU mode
+		device = device_id
+
+	model.to_device(device)
+
+	n_batches = int(np.ceil(len(iterator.dataset) / iterator.batch_size))
+
+	preds = []
+	labels = []
+
+	iterator.reset()
+	for batch in tqdm(iterator, total=n_batches):
+		X, y = concat_examples(batch, device=device)
+
+		# shape is (batch_size, #classes)
+		logits = model(X)
+
+		logits.to_cpu()
+
+		# get the class ID with the highest score
+		preds.extend(logits.array.argmax(axis=-1))
+		labels.extend(chainer.cuda.to_cpu(y))
+
+	return np.mean(np.array(preds) == np.array(labels))
+
+def load_model(weights: str):
+
+	model = ModelFactory.new("cvmodelz.InceptionV3", n_classes=200)
+	model.load(weights, path="model/")
+	return model
+
+
+def main(args):
+	if args.debug:
+		chainer.set_debug(args.debug)
+		print("=== Chainer info ===")
+		chainer.config.show()
+		print("=== CuPy info ===")
+		cupy.show_config()
+
+	model = load_model(args.weights)
+
+	print(f"Created {model.meta.name} model with weights from \"{args.weights}\"")
+
+	train_ds, val_ds = ds.load_datasets(
+		root=args.dataset,
+		model_input_size=model.meta.input_size,
+		prepare=model.prepare,
+		split_id=args.split_id)
+
+	print(f"Found {len(train_ds)} training and {len(val_ds)} validation images")
+
+	# we ignore the training dataset here,
+	# since we are only interested in evaluation
+	val_it = ds.new_iterator(val_ds,
+							 n_jobs=args.n_jobs,
+							 batch_size=args.batch_size,
+							 repeat=False,
+							 shuffle=False
+							)
+
+	accu = evaluate(model, val_it, device_id=args.device_id)
+	print(f"Accuracy: {accu:.2%}")
+
+chainer.config.cv_resize_backend = "cv2"
+chainer.config.train = False
+chainer.config.enable_backprop = False
+
+main(parser.parse_args())

+ 0 - 0
moth_classifier/__init__.py


+ 98 - 0
moth_classifier/dataset.py

@@ -0,0 +1,98 @@
+import numpy as np
+
+from chainer import iterators
+from chainer.dataset import DatasetMixin
+from chainer.datasets import TransformDataset
+from chainercv import transforms as tr
+from imageio import imread
+from pathlib import Path
+from typing import Callable
+
+class Dataset(DatasetMixin):
+
+    def __init__(self, root: str, split_id: int, is_train: bool = True):
+        super().__init__()
+
+        root = Path(root)
+        self._root = root
+        self.class_names = np.loadtxt(root / "class_names.txt", dtype="U255")
+
+        # read annoations from the root folder
+        _images = np.loadtxt(root / "images.txt", dtype=[("id", np.int32), ("fname", "U255")])
+        _labels = np.loadtxt(root / "labels.txt", dtype=np.int32)
+        _split_ids = np.loadtxt(root / "tr_ID.txt", dtype=np.int32)
+
+        if is_train:
+            # select all other splits
+            split_mask = _split_ids != split_id
+
+        else:
+            # select only images for a given split ID
+            split_mask = _split_ids == split_id
+
+        self.images = _images["fname"][split_mask]
+        self.labels = _labels[split_mask]
+
+
+    def __len__(self):
+        return len(self.images)
+
+    def get_example(self, i):
+        """ Here the images are loaded """
+        im_path = self._root / "images" / self.images[i]
+        label = self.labels[i]
+        return imread(im_path, pilmode="RGB"), label
+
+class DataTransformer(object):
+
+    def __init__(self, prepare: Callable, size: int):
+        super().__init__()
+        self.prepare = prepare
+        self.size = size
+
+    def __call__(self, data):
+        """
+            Before passing the data to the CNN, it needs
+            to be transformed:
+                - resize with the "prepare" function of the model
+                - center crop to the size of the CNN input
+                - rescale the pixel range from [0..1] tp [-1..1]
+                  (the CNN was trained with pixel range)
+        """
+        image, label = data
+        new_image = self.prepare(image, self.size)
+
+        new_image = tr.center_crop(new_image, size=(self.size, self.size))
+
+        # transform the pixel range from 0..1 to -1..1
+        new_image = new_image * 2 - 1
+        return new_image, label
+
+def load_datasets(root: Path, model_input_size: int, prepare: Callable, split_id: int):
+    """
+        load the two dataset splits (training and evaluation)
+        and return these as Dataset instances
+    """
+    train_ds = Dataset(root, split_id=split_id, is_train=True)
+    val_ds = Dataset(root, split_id=split_id, is_train=False)
+
+    transformer = DataTransformer(prepare, model_input_size)
+    train_ds = TransformDataset(train_ds, transformer)
+    val_ds = TransformDataset(val_ds, transformer)
+
+    return train_ds, val_ds
+
+def new_iterator(dataset, n_jobs: int = -1,  **kwargs):
+    """
+        Depending on the n_jobs argument create either a single-thread
+        serial iterator (n_jobs < 1) or a multi-thread iterator.
+        Iterators are responsible to gather the images from the dataset
+        and group it to a batch.
+    """
+    it_cls = iterators.SerialIterator
+
+    if n_jobs >= 1:
+        kwargs["n_threads"] = n_jobs
+        it_cls = iterators.MultithreadIterator
+
+    return it_cls(dataset, **kwargs)

+ 20 - 0
moth_classifier/parser.py

@@ -0,0 +1,20 @@
+from cvargparse import Arg
+from cvargparse import BaseParser
+
+
+def parse_args():
+	parser = BaseParser()
+
+	parser.add_args([
+		Arg("--dataset", "-ds", default="data/eu_moths"),
+		Arg("--weights", "-w", default="data/clf_final.npz"),
+
+		Arg("--split_id", "-split", type=int, default=0),
+		Arg("--batch_size", "-bs", type=int, default=32),
+		Arg("--device_id", "-g", type=int, default=0),
+		Arg("--n_jobs", "-j", type=int, default=4),
+
+		Arg("--debug", action="store_true"),
+	])
+
+	return parser.parse_args()