|
@@ -15,9 +15,10 @@ from torchvision.utils import save_image
|
|
|
from torchinfo import summary
|
|
|
|
|
|
from py.PyTorchData import create_dataloader, model_output_to_image
|
|
|
+from py.Dataset import Dataset
|
|
|
from py.Autoencoder2 import Autoencoder
|
|
|
|
|
|
-def train_autoencoder(model: Autoencoder, train_dataloader: DataLoader, name: str, device: str = "cpu", num_epochs=100, criterion = nn.MSELoss(), lr: float = 1e-3, weight_decay: float = 1e-5, noise: bool = False, sparse: bool = False, reg_rate: float = 1e-4):
|
|
|
+def train_autoencoder(model: Autoencoder, train_dataloader: DataLoader, name: str, device: str = "cpu", num_epochs: int = 100, criterion: nn.Module = nn.MSELoss(), lr: float = 1e-3, weight_decay: float = 1e-5, noise: bool = False, sparse: bool = False, reg_rate: float = 1e-4, noise_var: float = 0.015):
|
|
|
model = model.to(device)
|
|
|
print(f"Using {device} device")
|
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
|
@@ -34,7 +35,7 @@ def train_autoencoder(model: Autoencoder, train_dataloader: DataLoader, name: st
|
|
|
img = Variable(img).to(device)
|
|
|
input = img
|
|
|
if noise:
|
|
|
- input = input + (0.015 ** 0.5) * torch.randn(img.size(), device=device)
|
|
|
+ input = input + (noise_var ** 0.5) * torch.randn(img.size(), device=device)
|
|
|
# ===================forward=====================
|
|
|
latent = model.encoder(input)
|
|
|
output = model.decoder(latent)
|
|
@@ -59,7 +60,7 @@ def train_autoencoder(model: Autoencoder, train_dataloader: DataLoader, name: st
|
|
|
f.write(f"{dsp_epoch},{total_loss},{total_reg_loss}\n")
|
|
|
|
|
|
# output image
|
|
|
- if epoch % 2 == 0:
|
|
|
+ if epoch % 10 == 0:
|
|
|
pic = model_output_to_image(output.cpu().data)
|
|
|
save_image(pic, f"./ae_train_NoBackup/{name}/image_{dsp_epoch:03d}.png")
|
|
|
|
|
@@ -73,7 +74,8 @@ def train_autoencoder(model: Autoencoder, train_dataloader: DataLoader, name: st
|
|
|
if __name__ == "__main__":
|
|
|
parser = argparse.ArgumentParser(description="Autoencoder train script")
|
|
|
parser.add_argument("name", type=str, help="Name of the training session (name of the save folder)")
|
|
|
- parser.add_argument("img_folder", type=str, help="Path to directory containing train images (may contain subfolders)")
|
|
|
+ parser.add_argument("dataset_folder", type=str, help="Path to dataset folder containing sessions")
|
|
|
+ parser.add_argument("session", type=str, help="Session name")
|
|
|
parser.add_argument("--device", type=str, help="PyTorch device to train on (cpu or cuda)", default="cpu")
|
|
|
parser.add_argument("--epochs", type=int, help="Number of epochs", default=100)
|
|
|
parser.add_argument("--batch_size", type=int, help="Batch size (>=1)", default=32)
|
|
@@ -83,20 +85,24 @@ if __name__ == "__main__":
|
|
|
parser.add_argument("--latent", type=int, help="Number of latent features", default=512)
|
|
|
parser.add_argument("--image_transforms", action="store_true", help="Truncate and resize images (only enable if the input images have not been truncated resized to the target size already)")
|
|
|
parser.add_argument("--noise", action="store_true", help="Add Gaussian noise to model input")
|
|
|
+    parser.add_argument("--noise_var", type=float, help="Variance of the Gaussian noise added to the model input (only used with --noise)", default=0.015)
|
|
|
parser.add_argument("--sparse", action="store_true", help="Add L1 penalty to latent features")
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
+ ds = Dataset(args.dataset_folder)
|
|
|
+ session = ds.create_session(args.session)
|
|
|
+
|
|
|
if args.image_transforms:
|
|
|
print("Image transforms enabled: Images will be truncated and resized.")
|
|
|
else:
|
|
|
print("Image transforms disabled: Images are expected to be of the right size.")
|
|
|
|
|
|
- torch.manual_seed(10810)
|
|
|
- np.random.seed(10810)
|
|
|
- random.seed(10810)
|
|
|
+ # torch.manual_seed(10810)
|
|
|
+ # np.random.seed(10810)
|
|
|
+ # random.seed(10810)
|
|
|
|
|
|
- data_loader = create_dataloader(args.img_folder, batch_size=args.batch_size, skip_transforms=not args.image_transforms)
|
|
|
+ data_loader = create_dataloader(session.get_lapse_folder(), batch_size=args.batch_size, skip_transforms=not args.image_transforms)
|
|
|
model = Autoencoder(dropout=args.dropout, latent_features=args.latent)
|
|
|
print("Model:")
|
|
|
summary(model, (args.batch_size, 3, 256, 256))
|
|
@@ -108,4 +114,4 @@ if __name__ == "__main__":
|
|
|
print("Adding Gaussian noise to model input")
|
|
|
if args.sparse:
|
|
|
print("Adding L1 penalty to latent features (sparse)")
|
|
|
- train_autoencoder(model, data_loader, args.name, device=args.device, num_epochs=args.epochs, lr=args.lr, noise=args.noise, sparse=args.sparse, reg_rate=args.reg_rate)
|
|
|
+ train_autoencoder(model, data_loader, args.name, device=args.device, num_epochs=args.epochs, lr=args.lr, noise=args.noise, sparse=args.sparse, reg_rate=args.reg_rate, noise_var=args.noise_var)
|