
#!/usr/bin/env python
if __name__ != '__main__': raise Exception("Do not import me!")
import os
# limit OpenMP threads; must be set before numpy/chainer are imported
os.environ["OMP_NUM_THREADS"] = "2"

import chainer
import cupy
import numpy as np

from chainer.backends import cuda
from chainer.dataset import concat_examples
from chainer.dataset.iterator import Iterator

from cvargparse import Arg
from cvargparse import BaseParser

from cvmodelz.models import ModelFactory
from tqdm.auto import tqdm

from moth_classifier import dataset as ds
from moth_classifier import parser


def evaluate(model: chainer.Chain, iterator: Iterator, device_id: int = -1) -> float:
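    """Run the model over every batch of `iterator` and return top-1 accuracy.

    A non-negative `device_id` selects that GPU; -1 keeps everything on the CPU.
    """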
    if device_id >= 0:
        device = cuda.get_device_from_id(device_id)
        device.use()
    else:
        # for CPU mode
        device = device_id

    model.to_device(device)

    # ceil division: the last batch may be smaller than batch_size
    n_batches = int(np.ceil(len(iterator.dataset) / iterator.batch_size))

    preds = []
    labels = []

    iterator.reset()
    for batch in tqdm(iterator, total=n_batches):
        X, y = concat_examples(batch, device=device)

        # shape is (batch_size, #classes)
        logits = model(X)
        logits.to_cpu()

        # get the class ID with the highest score
        preds.extend(logits.array.argmax(axis=-1))
        labels.extend(chainer.cuda.to_cpu(y))

    return np.mean(np.array(preds) == np.array(labels))

def load_model(weights: str):
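    """Create an InceptionV3 with 200 output classes and load the given weights.

    The `path="model/"` argument is presumably the key prefix under which
    the parameters were serialized in the checkpoint.
    """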
    model = ModelFactory.new("cvmodelz.InceptionV3", n_classes=200)
    model.load(weights, path="model/")
    return model

def main(args):
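    """Load the model, build the validation iterator, and report accuracy."""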
    if args.debug:
        chainer.set_debug(args.debug)
        print("=== Chainer info ===")
        chainer.config.show()
        print("=== CuPy info ===")
        cupy.show_config()

    model = load_model(args.weights)
    print(f"Created {model.meta.name} model with weights from \"{args.weights}\"")

    train_ds, val_ds = ds.load_datasets(
        root=args.dataset,
        model_input_size=model.meta.input_size,
        prepare=model.prepare,
        split_id=args.split_id)

    print(f"Found {len(train_ds)} training and {len(val_ds)} validation images")

    # we ignore the training dataset here,
    # since we are only interested in evaluation
    val_it = ds.new_iterator(val_ds,
        n_jobs=args.n_jobs,
        batch_size=args.batch_size,
        repeat=False,
        shuffle=False,
    )

    accu = evaluate(model, val_it, device_id=args.device_id)
    print(f"Accuracy: {accu:.2%}")

# global inference settings: use OpenCV for image resizing and disable
# train-mode behavior (dropout, BN statistics updates) as well as autograd
chainer.config.cv_resize_backend = "cv2"
chainer.config.train = False
chainer.config.enable_backprop = False

main(parser.parse_args())
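
# Example invocation, assuming moth_classifier.parser exposes the arguments
# used above under the same names (actual flag spelling may differ):
#   ./main.py <dataset-root> <weights.npz> --device_id 0 --batch_size 32 --n_jobs 4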