# unit_tests_utils.py — shared helpers for the unit-test suite.
  1. from os.path import expanduser
  2. import os
  3. import torch
  4. from sklearn.datasets import make_classification
  5. from sklearn.model_selection import train_test_split
  6. from torch.utils.data import TensorDataset
  7. from torch.utils.data.dataloader import DataLoader
  8. from torchvision.datasets import MNIST
  9. from torchvision.transforms import Compose, ToTensor
  10. from avalanche.benchmarks import nc_benchmark
  11. def common_setups():
  12. # adapt_dataset_urls()
  13. pass
  14. def load_benchmark(use_task_labels=False, fast_test=True):
  15. """
  16. Returns a NC Benchmark from a fake dataset of 10 classes, 5 experiences,
  17. 2 classes per experience.
  18. """
  19. if fast_test:
  20. my_nc_benchmark = get_fast_benchmark(use_task_labels)
  21. else:
  22. mnist_train = MNIST(
  23. root=expanduser("~") + "/.avalanche/data/mnist/",
  24. train=True, download=True,
  25. transform=Compose([ToTensor()]))
  26. mnist_test = MNIST(
  27. root=expanduser("~") + "/.avalanche/data/mnist/",
  28. train=False, download=True,
  29. transform=Compose([ToTensor()]))
  30. my_nc_benchmark = nc_benchmark(
  31. mnist_train, mnist_test, 5,
  32. task_labels=use_task_labels, seed=1234)
  33. return my_nc_benchmark
  34. def get_fast_benchmark(use_task_labels=False, shuffle=True):
  35. n_samples_per_class = 100
  36. dataset = make_classification(
  37. n_samples=10 * n_samples_per_class,
  38. n_classes=10,
  39. n_features=6, n_informative=6, n_redundant=0)
  40. X = torch.from_numpy(dataset[0]).float()
  41. y = torch.from_numpy(dataset[1]).long()
  42. train_X, test_X, train_y, test_y = train_test_split(
  43. X, y, train_size=0.6, shuffle=True, stratify=y)
  44. train_dataset = TensorDataset(train_X, train_y)
  45. test_dataset = TensorDataset(test_X, test_y)
  46. my_nc_benchmark = nc_benchmark(train_dataset, test_dataset, 5,
  47. task_labels=use_task_labels, shuffle=shuffle)
  48. return my_nc_benchmark
  49. def load_experience_train_eval(experience, batch_size=32, num_workers=0):
  50. for x, y, t in DataLoader(experience.dataset.train(), batch_size=batch_size,
  51. num_workers=num_workers):
  52. break
  53. for x, y, t in DataLoader(experience.dataset.eval(), batch_size=batch_size,
  54. num_workers=num_workers):
  55. break
  56. def get_device():
  57. if "USE_GPU" in os.environ:
  58. use_gpu = os.environ['USE_GPU'].lower() in ["true"]
  59. else:
  60. use_gpu = False
  61. print("Test on GPU:", use_gpu)
  62. if use_gpu:
  63. device = "cuda"
  64. else:
  65. device = "cpu"
  66. return device
  67. __all__ = [
  68. 'common_setups',
  69. 'load_benchmark',
  70. 'get_fast_benchmark',
  71. 'load_experience_train_eval',
  72. 'get_device'
  73. ]