caffe_load_network.m 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  function [net, mean_data] = caffe_load_network ( s_pathtodeployfile, s_pathtomodel, s_phase, s_meanfile, b_reshape_for_single_image_processing)
  % function [net, mean_data] = caffe_load_network ( s_pathtodeployfile, s_pathtomodel, s_phase, s_meanfile, b_reshape_for_single_image_processing)
  %
  % BRIEF
  %   Load a specified network with mean image for train or test.
  %   Requires Caffe version from 17-07-2015 (hash: 6d92d8fcfe0eea9495ffbc)
  %
  % INPUT
  %   s_pathtodeployfile -- string, path to file describing the network's
  %                         architecture, e.g., deploy.prototxt
  %   s_pathtomodel      -- string, path to file describing the network's
  %                         learned parameters, e.g. bvlc_reference_caffenet.caffemodel
  %   s_phase            -- string, 'test' or 'train'. 'test' will de-activate dropout
  %   s_meanfile         -- string, path to file which contains the mean
  %                         image of a dataset, e.g., imagenet_mean.binaryproto
  %   b_reshape_for_single_image_processing
  %                      -- bool (optional, default false), true if the network
  %                         shall operate on single images instead of image batches
  %
  % OUTPUT
  %   net       -- network object returned by caffe.Net, with weights loaded
  %                (and blob 'data' reshaped to hold one image if requested)
  %   mean_data -- mean image read from s_meanfile, center-cropped in its
  %                first two dimensions to the first two dimensions of the
  %                network's 'data' blob; all entries of the third
  %                dimension are kept
  %
  % NOTE
  %   The mean image must be at least as large as the network input in its
  %   first two dimensions; otherwise an error is raised.

    % treat a missing or empty flag as "keep the batch size of the deploy file"
    if ( nargin < 5 || isempty( b_reshape_for_single_image_processing ) )
        b_reshape_for_single_image_processing = false;
    end

    %% load network
    net = caffe.Net(s_pathtodeployfile, s_pathtomodel, s_phase); % create net and load weights

    % Query the input size before the optional reshape: it is needed for
    % cropping the mean image in both branches, and the reshape below only
    % changes the last dimension.
    i_currentNetInputSize = net.blobs('data').shape;

    %% reshape network input for batch or single-image processing
    if ( b_reshape_for_single_image_processing )
        % Keep the first three dimensions of blob 'data' and set the fourth
        % one to 1, so that a forward pass processes a single image instead
        % of the batch size declared in the deploy file.
        net.blobs('data').reshape([i_currentNetInputSize(1) ...
                                   i_currentNetInputSize(2) ...
                                   i_currentNetInputSize(3) ...
                                   1 ...
                                  ]); % reshape blob 'data'
        net.reshape();
    end

    %% load mean image and adapt to network input size
    mean_data = caffe.io.read_mean( s_meanfile );

    % crop center from mean file according to network size
    i_sizeMean = size( mean_data );
    d_sizeDiff = double( i_sizeMean(1:2) ) - double( i_currentNetInputSize(1:2) );
    if ( any( d_sizeDiff < 0 ) )
        error( 'caffe_load_network:meanTooSmall', ...
               'Mean image (%d x %d) is smaller than network input (%d x %d).', ...
               i_sizeMean(1), i_sizeMean(2), ...
               i_currentNetInputSize(1), i_currentNetInputSize(2) );
    end
    % Compute in double: MATLAB integer division rounds to nearest, so
    % floor(int32(d)/2) would round odd differences up instead of down.
    offset_row = floor( d_sizeDiff(1) / 2 ) + 1;
    offset_col = floor( d_sizeDiff(2) / 2 ) + 1;
    mean_data = mean_data( offset_row:offset_row+i_currentNetInputSize(1)-1, ...
                           offset_col:offset_col+i_currentNetInputSize(2)-1, ...
                           : ...
                         );
  end