#!/usr/bin/env python
if __name__ != '__main__': raise Exception("Do not import me!")

import os
# set before the numeric libraries below are imported
os.environ["OMP_NUM_THREADS"] = "2"

import chainer
import cupy
import numpy as np

from chainer.backends import cuda
from chainer.dataset import concat_examples
from chainer.dataset.iterator import Iterator
from cvargparse import Arg
from cvargparse import BaseParser
from cvmodelz.models import ModelFactory
from tqdm.auto import tqdm

from moth_classifier import dataset as ds
from moth_classifier import parser


def evaluate(model: chainer.Chain, iterator: Iterator, device_id: int = -1) -> float:
    """Run the model over one pass of the iterator and return the accuracy.

    Args:
        model: the network; it is moved to the selected device and then
            called on every input batch.
        iterator: yields the batches. Its ``dataset`` and ``batch_size``
            attributes are read for the progress bar total, and it is
            reset before iterating.
        device_id: GPU ID to evaluate on; a negative value selects CPU mode.

    Returns:
        The mean of the element-wise comparison between the arg-max
        predictions and the labels.
    """
    if device_id < 0:
        # CPU mode: the raw (negative) ID is passed on as the device
        target = device_id
    else:
        target = cuda.get_device_from_id(device_id)
        target.use()

    model.to_device(target)

    n_samples = len(iterator.dataset)
    total = int(np.ceil(n_samples / iterator.batch_size))

    predicted = []
    expected = []

    iterator.reset()
    for minibatch in tqdm(iterator, total=total):
        inputs, targets = concat_examples(minibatch, device=target)

        # scores are laid out as (batch_size, #classes)
        scores = model(inputs)
        scores.to_cpu()

        # the index of the largest score along the last axis is the predicted class ID
        predicted.extend(scores.array.argmax(axis=-1))
        expected.extend(chainer.cuda.to_cpu(targets))

    hits = np.array(predicted) == np.array(expected)
    return np.mean(hits)


def load_model(weights: str):
    """Create a ``cvmodelz.InceptionV3`` model with 200 classes and load weights.

    Args:
        weights: passed to ``model.load`` together with ``path="model/"``.

    Returns:
        The model after loading the weights.
    """
    net = ModelFactory.new("cvmodelz.InceptionV3", n_classes=200)
    net.load(weights, path="model/")
    return net


def main(args):
    """Evaluate the stored weights on the validation split and print the accuracy.

    Reads ``debug``, ``weights``, ``dataset``, ``split_id``, ``n_jobs``,
    ``batch_size`` and ``device_id`` from ``args``.
    """
    if args.debug:
        chainer.set_debug(args.debug)
        # NOTE(review): the info dumps are assumed to belong to the debug
        # branch -- confirm against the original layout
        print("=== Chainer info ===")
        chainer.config.show()
        print("=== CuPy info ===")
        cupy.show_config()

    model = load_model(args.weights)
    print(f"Created {model.meta.name} model with weights from \"{args.weights}\"")

    train_ds, val_ds = ds.load_datasets(
        root=args.dataset,
        model_input_size=model.meta.input_size,
        prepare=model.prepare,
        split_id=args.split_id,
    )
    print(f"Found {len(train_ds)} training and {len(val_ds)} validation images")

    # only the validation split is evaluated; the training split
    # is not used any further
    val_it = ds.new_iterator(
        val_ds,
        n_jobs=args.n_jobs,
        batch_size=args.batch_size,
        repeat=False,
        shuffle=False,
    )

    accu = evaluate(model, val_it, device_id=args.device_id)
    print(f"Accuracy: {accu:.2%}")


# global Chainer configuration, applied before the evaluation is started
chainer.config.cv_resize_backend = "cv2"
chainer.config.train = False
chainer.config.enable_backprop = False

main(parser.parse_args())