################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 21-06-2020                                                             #
# Author(s): Lorenzo Pellegrini, Vincenzo Lomonaco                             #
# E-mail: contact@continualai.org                                              #
# Website: continualai.org                                                     #
################################################################################
  11. """ This module contains useful utility functions and classes to generate
  12. pytorch datasets based on filelists (Caffe style) """
  13. from pathlib import Path
  14. from typing import Tuple, Sequence, Optional
  15. import torch.utils.data as data
  16. from PIL import Image
  17. import os
  18. import os.path
  19. from torch import Tensor
  20. from torchvision.transforms.functional import crop
  21. from avalanche.benchmarks.utils import AvalancheDataset


def default_image_loader(path):
    """
    Sets the default image loader for the Pytorch Dataset.

    :param path: relative or absolute path of the file to load.
    :returns: Returns the image as an RGB PIL image.
    """
    return Image.open(path).convert('RGB')


def default_flist_reader(flist):
    """
    This reader reads a filelist and returns a list of (path, label) tuples.

    :param flist: path of the filelist to read. The flist format should be:
        impath label\nimpath label\n ... (same as Caffe's filelist).
    :returns: Returns a list of (path, label) tuples (the examples to load).
    """
    imlist = []
    with open(flist, 'r') as rf:
        for line in rf.readlines():
            impath, imlabel = line.strip().split()
            imlist.append((impath, int(imlabel)))
    return imlist
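
# A quick illustration of the expected filelist format (the file name and its
# content below are hypothetical). Each line holds a path and an integer label
# separated by a single space:
#
#     images/cat_001.png 0
#     images/dog_042.png 1
#
# so that:
#
#     entries = default_flist_reader('train_exp0.txt')
#     # entries == [('images/cat_001.png', 0), ('images/dog_042.png', 1)]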


class PathsDataset(data.Dataset):
    """
    This class extends the basic Pytorch Dataset class to handle a list of
    paths as the main data source.
    """

    def __init__(
            self, root, files, transform=None, target_transform=None,
            loader=default_image_loader):
        """
        Creates a File Dataset from a list of files and labels.

        :param root: root path where the data to load are stored. May be None.
        :param files: list of tuples. Each tuple must contain two elements: the
            full path to the pattern and its class label. Optionally, the tuple
            may contain a third element describing the bounding box to use for
            cropping (top, left, height, width).
        :param transform: transformation to apply to the input data (x),
            if any.
        :param target_transform: transformation to apply to the targets (y),
            if any.
        :param loader: loader function to use (for the real data) given path.
        """

        if root is not None:
            root = Path(root)

        self.root: Optional[Path] = root
        self.imgs = files
        self.targets = [img_data[1] for img_data in self.imgs]
        self.paths = [img_data[0] for img_data in self.imgs]
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Returns the element of the dataset at the given index.

        :param index: index of the data to get.
        :return: loaded item.
        """

        img_description = self.imgs[index]
        impath = img_description[0]
        target = img_description[1]
        bbox = None
        if len(img_description) > 2:
            bbox = img_description[2]

        if self.root is not None:
            impath = self.root / impath

        success = False
        while not success:
            try:
                img = self.loader(impath)
                success = True
            except Exception as e:
                # Loading may fail transiently (e.g. on a network filesystem):
                # keep retrying. A persistent error will keep printing the
                # offending path.
                print('image could not be loaded!', e)
                print(impath)

        # If a bounding box is provided, crop the image before passing it to
        # any user-defined transformation.
        if bbox is not None:
            if isinstance(bbox, Tensor):
                bbox = bbox.tolist()
            img = crop(img, *bbox)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """
        Returns the total number of elements in the dataset.

        :return: Total number of dataset items.
        """

        return len(self.imgs)
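
# Illustrative sketch (root, paths and labels are hypothetical). The optional
# third element of each tuple is a (top, left, height, width) bounding box
# applied before any transformation:
#
#     files = [
#         ('cats/img_0001.png', 0),
#         ('dogs/img_0042.png', 1, (10, 10, 224, 224)),
#     ]
#     dataset = PathsDataset('/data/my_dataset', files)
#     img, label = dataset[0]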


class SeqPathsDataset(data.Dataset):
    """
    This class extends the basic Pytorch Dataset class to handle a list of
    (path, target, seq_id) tuples as the main data source.
    """

    def __init__(
            self, root, files, transform=None, target_transform=None,
            loader=default_image_loader):
        """
        Creates a File Dataset from a list of files and labels.

        :param root: root path where the data to load are stored. May be None.
        :param files: list of tuples. Each tuple must contain three elements:
            the full path to the pattern, its class label and the identifier
            of the sequence the pattern belongs to.
        :param transform: transformation to apply to the input data (x),
            if any.
        :param target_transform: transformation to apply to the targets (y),
            if any.
        :param loader: loader function to use (for the real data) given path.
        """

        if root is not None:
            root = Path(root)

        self.root: Optional[Path] = root
        self.imgs = files
        self.targets = [img_data[1] for img_data in self.imgs]
        self.paths = [img_data[0] for img_data in self.imgs]
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Returns the element of the dataset at the given index.

        :param index: index of the data to get.
        :return: loaded item.
        """

        img_description = self.imgs[index]
        impath = img_description[0]
        target = img_description[1]
        seq_code = img_description[2]

        if self.root is not None:
            impath = self.root / impath

        success = False
        while not success:
            try:
                img = self.loader(impath)
                success = True
            except Exception as e:
                # Loading may fail transiently: keep retrying. A persistent
                # error will keep printing the offending path.
                print('image could not be loaded!', e)
                print(impath)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, str(impath), seq_code

    def __len__(self):
        """
        Returns the total number of elements in the dataset.

        :return: Total number of dataset items.
        """

        return len(self.imgs)
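
# Illustrative sketch (root, paths, labels and sequence ids are hypothetical).
# The third element of each tuple is the sequence identifier, returned
# alongside the image, label and resolved path:
#
#     files = [
#         ('seq_03/frame_001.png', 2, 3),
#         ('seq_03/frame_002.png', 2, 3),
#     ]
#     dataset = SeqPathsDataset('/data/my_videos', files)
#     img, label, path, seq_id = dataset[0]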


class FilelistDataset(PathsDataset):
    """
    This class extends the basic Pytorch Dataset class to handle filelists as
    the main data source.
    """

    def __init__(
            self, root, flist, transform=None, target_transform=None,
            flist_reader=default_flist_reader, loader=default_image_loader):
        """
        Creates a dataset from a Caffe-style filelist.

        :param root: root path where the data to load are stored. May be None.
        :param flist: path of the filelist to read. The flist format should be:
            impath label\nimpath label\n ... (same as Caffe's filelist).
        :param transform: transformation to apply to the input data (x),
            if any.
        :param target_transform: transformation to apply to the targets (y),
            if any.
        :param flist_reader: loader function to use (for the filelists) given
            path.
        :param loader: loader function to use (for the real data) given path.
        """

        flist = str(flist)  # Manages Path objects
        files_and_labels = flist_reader(flist)
        super().__init__(root, files_and_labels, transform=transform,
                         target_transform=target_transform, loader=loader)
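
# Illustrative sketch (root and filelist name are hypothetical):
#
#     dataset = FilelistDataset('/data/my_dataset', 'train_exp0.txt')
#     img, label = dataset[0]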


def datasets_from_filelists(root, train_filelists, test_filelists,
                            complete_test_set_only=False,
                            train_transform=None, train_target_transform=None,
                            test_transform=None, test_target_transform=None):
    """
    This reader reads a list of Caffe-style filelists and returns the proper
    Dataset objects.

    A Caffe-style filelist is just a text file where each line describes two
    elements: the path to the pattern (relative to the root parameter) and its
    class label. Those two elements are separated by a single white space.

    This method reads each filelist and returns a separate dataset for each of
    them.

    Beware that the parameters must be **lists of paths to Caffe-style
    filelists**. If you need to create a dataset given a list of
    **pattern paths**, use `datasets_from_paths` instead.

    :param root: root path where the data to load are stored. May be None.
    :param train_filelists: list of paths to train filelists. The flist format
        should be: impath label\\nimpath label\\n ... (same as Caffe's
        filelist).
    :param test_filelists: list of paths to test filelists. It can also be a
        single path when the test set is the same for each experience.
    :param complete_test_set_only: if True, test_filelists must contain
        the path to a single filelist that will serve as the complete test set.
        Alternatively, test_filelists can be the path (str) to the complete
        test set filelist. If False, train_filelists and test_filelists must
        contain the same amount of filelist paths. Defaults to False.
    :param train_transform: The transformation to apply to training patterns.
        Defaults to None.
    :param train_target_transform: The transformation to apply to training
        patterns targets. Defaults to None.
    :param test_transform: The transformation to apply to test patterns.
        Defaults to None.
    :param test_target_transform: The transformation to apply to test
        patterns targets. Defaults to None.

    :return: list of tuples (train dataset, test dataset) for each train
        filelist in the list.
    """

    if complete_test_set_only:
        if not (isinstance(test_filelists, str) or
                isinstance(test_filelists, Path)):
            if len(test_filelists) > 1:
                raise ValueError(
                    'When complete_test_set_only is True, test_filelists must '
                    'be a str, Path or a list with a single element describing '
                    'the path to the complete test set.')
        else:
            test_filelists = [test_filelists]
    else:
        if len(test_filelists) != len(train_filelists):
            raise ValueError(
                'When complete_test_set_only is False, test_filelists and '
                'train_filelists must contain the same number of elements.')

    transform_groups = dict(train=(train_transform, train_target_transform),
                            eval=(test_transform, test_target_transform))
    train_inc_datasets = \
        [AvalancheDataset(FilelistDataset(root, tr_flist),
                          transform_groups=transform_groups,
                          initial_transform_group='train')
         for tr_flist in train_filelists]
    test_inc_datasets = \
        [AvalancheDataset(FilelistDataset(root, te_flist),
                          transform_groups=transform_groups,
                          initial_transform_group='eval')
         for te_flist in test_filelists]

    return train_inc_datasets, test_inc_datasets
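
# Illustrative sketch (root, filelist names and transforms are hypothetical;
# ToTensor comes from torchvision.transforms):
#
#     train_sets, test_sets = datasets_from_filelists(
#         '/data/my_dataset',
#         ['train_exp0.txt', 'train_exp1.txt'],
#         ['test_exp0.txt', 'test_exp1.txt'],
#         train_transform=ToTensor(),
#         test_transform=ToTensor())
#     # train_sets[0] / test_sets[0] are the datasets of the first experience.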


def datasets_from_paths(
        train_list, test_list, complete_test_set_only=False,
        train_transform=None, train_target_transform=None,
        test_transform=None, test_target_transform=None):
    """
    This utility takes, for each dataset to generate, a list of tuples each
    containing two elements: the full path to the pattern and its class label.
    Optionally, each tuple may contain a third element describing the bounding
    box to use for cropping.

    This is equivalent to `datasets_from_filelists`, whose description
    contains more details on the behaviour of this utility. The two utilities
    differ in that `datasets_from_filelists` accepts paths to Caffe-style
    filelists while this one is able to create the datasets from an in-memory
    list.

    Note: this utility may try to detect (and strip) the common root path of
    all patterns in order to save some RAM.

    :param train_list: list of lists. Each list must contain tuples of two
        elements: the full path to the pattern and its class label.
        Optionally, the tuple may contain a third element describing the
        bounding box to use for cropping (top, left, height, width).
    :param test_list: list of lists. Each list must contain tuples of two
        elements: the full path to the pattern and its class label.
        Optionally, the tuple may contain a third element describing the
        bounding box to use for cropping (top, left, height, width). It can
        also be a single list when the test dataset is the same for each
        experience.
    :param complete_test_set_only: if True, test_list must contain a single
        list that will serve as the complete test set. If False, train_list
        and test_list must describe the same amount of datasets. Defaults to
        False.
    :param train_transform: The transformation to apply to training patterns.
        Defaults to None.
    :param train_target_transform: The transformation to apply to training
        patterns targets. Defaults to None.
    :param test_transform: The transformation to apply to test patterns.
        Defaults to None.
    :param test_target_transform: The transformation to apply to test
        patterns targets. Defaults to None.

    :return: A list of tuples (train dataset, test dataset).
    """

    if complete_test_set_only:
        # Check if the single dataset was passed as [Tuple1, Tuple2, ...]
        # or as [[Tuple1, Tuple2, ...]]
        if not isinstance(test_list[0], tuple):
            if len(test_list) > 1:
                raise ValueError(
                    'When complete_test_set_only is True, test_list must '
                    'be a single list of tuples or a nested list containing '
                    'a single list of tuples.')
        else:
            test_list = [test_list]
    else:
        if len(test_list) != len(train_list):
            raise ValueError(
                'When complete_test_set_only is False, test_list and '
                'train_list must contain the same number of elements.')

    transform_groups = dict(train=(train_transform, train_target_transform),
                            eval=(test_transform, test_target_transform))

    common_root = None

    # Detect common root
    try:
        all_paths = [pattern_tuple[0] for exp_list in train_list
                     for pattern_tuple in exp_list] + \
                    [pattern_tuple[0] for exp_list in test_list
                     for pattern_tuple in exp_list]
        common_root = os.path.commonpath(all_paths)
    except ValueError:
        # commonpath may throw a ValueError in different situations!
        # See the official documentation for more details
        pass

    if common_root is not None and len(common_root) > 0 and \
            common_root != '/':
        has_common_root = True
        common_root = str(common_root)
    else:
        has_common_root = False
        common_root = None

    if has_common_root:
        # print(f'Common root found: {common_root}!')
        # All paths have a common filesystem root
        # Remove it from all paths!
        single_path_case = False
        tr_list = list()
        te_list = list()

        for idx_exp_list in range(len(train_list)):
            if single_path_case:
                break
            st_list = list()
            for x in train_list[idx_exp_list]:
                rel = os.path.relpath(x[0], common_root)
                if len(rel) == 0 or rel == '.':
                    # May happen if the dataset has a single path
                    single_path_case = True
                    break
                st_list.append((rel, *x[1:]))
            tr_list.append(st_list)

        for idx_exp_list in range(len(test_list)):
            if single_path_case:
                break
            st_list = list()
            for x in test_list[idx_exp_list]:
                rel = os.path.relpath(x[0], common_root)
                if len(rel) == 0 or rel == '.':
                    # May happen if the dataset has a single path
                    single_path_case = True
                    break
                st_list.append((rel, *x[1:]))
            te_list.append(st_list)

        if not single_path_case:
            train_list = tr_list
            test_list = te_list
        else:
            has_common_root = False
            common_root = None

    train_inc_datasets = \
        [AvalancheDataset(PathsDataset(common_root, tr_flist),
                          transform_groups=transform_groups,
                          initial_transform_group='train')
         for tr_flist in train_list]
    test_inc_datasets = \
        [AvalancheDataset(PathsDataset(common_root, te_flist),
                          transform_groups=transform_groups,
                          initial_transform_group='eval')
         for te_flist in test_list]

    return train_inc_datasets, test_inc_datasets
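
# Illustrative sketch (paths and labels are hypothetical). The common root
# '/data/ds' is detected and stripped from the stored paths:
#
#     train_exp0 = [('/data/ds/cats/img_0001.png', 0),
#                   ('/data/ds/dogs/img_0042.png', 1)]
#     test_exp0 = [('/data/ds/cats/img_0100.png', 0)]
#     train_sets, test_sets = datasets_from_paths([train_exp0], [test_exp0])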


def common_paths_root(exp_list):
    """
    Detects the common root of the paths in `exp_list` and, if a usable one
    exists, strips it from all paths.

    :param exp_list: list of tuples whose first element is a pattern path.
    :return: a tuple (common_root, exp_list) where common_root is the common
        root path (or None if no usable common root was found) and exp_list
        contains the input tuples with paths made relative to that root.
    """

    common_root = None

    # Detect common root
    try:
        all_paths = [pattern_tuple[0] for pattern_tuple in exp_list]
        common_root = os.path.commonpath(all_paths)
    except ValueError:
        # commonpath may throw a ValueError in different situations!
        # See the official documentation for more details
        pass

    if common_root is not None and len(common_root) > 0 and \
            common_root != '/':
        has_common_root = True
        common_root = str(common_root)
    else:
        has_common_root = False
        common_root = None

    if has_common_root:
        # print(f'Common root found: {common_root}!')
        # All paths have a common filesystem root
        # Remove it from all paths!
        single_path_case = False
        exp_tuples = list()

        for x in exp_list:
            if single_path_case:
                break
            rel = os.path.relpath(x[0], common_root)
            if len(rel) == 0 or rel == '.':
                # May happen if the dataset has a single path
                single_path_case = True
                break
            exp_tuples.append((rel, *x[1:]))

        if not single_path_case:
            exp_list = exp_tuples
        else:
            common_root = None

    return common_root, exp_list
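
# Illustrative sketch (paths and labels are hypothetical):
#
#     exp_list = [('/data/ds/cats/img_0001.png', 0),
#                 ('/data/ds/dogs/img_0042.png', 1)]
#     root, rel_list = common_paths_root(exp_list)
#     # root == '/data/ds'
#     # rel_list == [('cats/img_0001.png', 0), ('dogs/img_0042.png', 1)]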


__all__ = [
    'default_image_loader',
    'default_flist_reader',
    'PathsDataset',
    'SeqPathsDataset',
    'FilelistDataset',
    'datasets_from_filelists',
    'datasets_from_paths',
    'common_paths_root'
]