from torch import nn
from torchinfo import summary
/home/kleinsteuber/anaconda3/envs/pytorch-gpu/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder stacks six stride-2 convolutions (each halves height and
    width) and then flattens to a 1024-element vector, i.e. 64 channels on a
    4x4 grid.  That flatten size fixes the expected input to 3-channel
    256x256 images.  A linear layer maps the flattened features to
    ``latent_features``.

    The decoder mirrors this: a linear layer expands the latent vector back
    to 1024 features, six stride-2 transposed convolutions each double the
    spatial size (4 -> 256), and a final 3x3 convolution followed by ``Tanh``
    produces a 3-channel output in the range [-1, 1].

    Args:
        dropout: Drop probability passed to every ``nn.Dropout`` layer.
        latent_features: Width of the bottleneck vector produced by the
            encoder and consumed by the decoder.
    """

    def __init__(self, dropout: float = 0.1, latent_features: int = 512):
        super().__init__()

        self.encoder = nn.Sequential(
            # 256 -> 128
            nn.Dropout(dropout),
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(True),
            # 128 -> 64
            nn.Dropout(dropout),
            nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(True),
            # 64 -> 32
            nn.Dropout(dropout),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            # 32 -> 16
            nn.Dropout(dropout),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            # 16 -> 8
            nn.Dropout(dropout),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            # 8 -> 4
            nn.Dropout(dropout),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            # Bottleneck: 64 * 4 * 4 = 1024 -> latent_features
            nn.Dropout(dropout),
            nn.Flatten(),
            nn.Linear(1024, latent_features),
            nn.ReLU(True),
        )

        self.decoder = nn.Sequential(
            # The input width must match the encoder's bottleneck width; a
            # hard-coded 512 here breaks every non-default latent_features.
            nn.Linear(latent_features, 1024),
            nn.ReLU(True),
            nn.Unflatten(1, (64, 4, 4)),
            # 4 -> 8
            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            # 8 -> 16
            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            # 16 -> 32
            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            # 32 -> 64
            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            # 64 -> 128
            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=6, stride=2, padding=2),
            nn.ReLU(True),
            # 128 -> 256
            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=8, stride=2, padding=3),
            nn.ReLU(True),
            # Project back to 3 channels without changing the spatial size.
            nn.Dropout(dropout),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding="same"),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` to the latent vector and decode it back to an image.

        Args:
            x: Input batch; the layer sizes require shape (N, 3, 256, 256).

        Returns:
            The reconstruction, shape (N, 3, 256, 256), values in [-1, 1].
        """
        x = self.encoder(x)
        x = self.decoder(x)
        return x
# Print a layer-by-layer torchinfo summary of a default Autoencoder for a
# batch of 32 three-channel 256x256 inputs.
# NOTE(review): the recorded output that follows lists 32-, 128- and
# 16-channel layers, which does not match the all-64-channel definition in
# this file -- it appears to come from a different revision; re-run to refresh.
summary(Autoencoder(), input_size=(32, 3, 256, 256))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Autoencoder                              [32, 3, 256, 256]         --
├─Sequential: 1-1                        [32, 512]                 --
│    └─Dropout: 2-1                      [32, 3, 256, 256]         --
│    └─Conv2d: 2-2                       [32, 32, 128, 128]        4,736
│    └─ReLU: 2-3                         [32, 32, 128, 128]        --
│    └─Dropout: 2-4                      [32, 32, 128, 128]        --
│    └─Conv2d: 2-5                       [32, 64, 64, 64]          51,264
│    └─ReLU: 2-6                         [32, 64, 64, 64]          --
│    └─Dropout: 2-7                      [32, 64, 64, 64]          --
│    └─Conv2d: 2-8                       [32, 64, 32, 32]          36,928
│    └─ReLU: 2-9                         [32, 64, 32, 32]          --
│    └─Dropout: 2-10                     [32, 64, 32, 32]          --
│    └─Conv2d: 2-11                      [32, 64, 16, 16]          36,928
│    └─ReLU: 2-12                        [32, 64, 16, 16]          --
│    └─Dropout: 2-13                     [32, 64, 16, 16]          --
│    └─Conv2d: 2-14                      [32, 128, 8, 8]           73,856
│    └─ReLU: 2-15                        [32, 128, 8, 8]           --
│    └─Dropout: 2-16                     [32, 128, 8, 8]           --
│    └─Conv2d: 2-17                      [32, 64, 4, 4]            73,792
│    └─ReLU: 2-18                        [32, 64, 4, 4]            --
│    └─Dropout: 2-19                     [32, 64, 4, 4]            --
│    └─Flatten: 2-20                     [32, 1024]                --
│    └─Linear: 2-21                      [32, 512]                 524,800
│    └─ReLU: 2-22                        [32, 512]                 --
├─Sequential: 1-2                        [32, 3, 256, 256]         --
│    └─Linear: 2-23                      [32, 1024]                525,312
│    └─ReLU: 2-24                        [32, 1024]                --
│    └─Unflatten: 2-25                   [32, 64, 4, 4]            --
│    └─Dropout: 2-26                     [32, 64, 4, 4]            --
│    └─ConvTranspose2d: 2-27             [32, 128, 8, 8]           131,200
│    └─ReLU: 2-28                        [32, 128, 8, 8]           --
│    └─Dropout: 2-29                     [32, 128, 8, 8]           --
│    └─ConvTranspose2d: 2-30             [32, 64, 16, 16]          131,136
│    └─ReLU: 2-31                        [32, 64, 16, 16]          --
│    └─Dropout: 2-32                     [32, 64, 16, 16]          --
│    └─ConvTranspose2d: 2-33             [32, 64, 32, 32]          65,600
│    └─ReLU: 2-34                        [32, 64, 32, 32]          --
│    └─Dropout: 2-35                     [32, 64, 32, 32]          --
│    └─ConvTranspose2d: 2-36             [32, 64, 64, 64]          65,600
│    └─ReLU: 2-37                        [32, 64, 64, 64]          --
│    └─Dropout: 2-38                     [32, 64, 64, 64]          --
│    └─ConvTranspose2d: 2-39             [32, 32, 128, 128]        73,760
│    └─ReLU: 2-40                        [32, 32, 128, 128]        --
│    └─Dropout: 2-41                     [32, 32, 128, 128]        --
│    └─ConvTranspose2d: 2-42             [32, 16, 256, 256]        32,784
│    └─ReLU: 2-43                        [32, 16, 256, 256]        --
│    └─Dropout: 2-44                     [32, 16, 256, 256]        --
│    └─Conv2d: 2-45                      [32, 3, 256, 256]         435
│    └─Tanh: 2-46                        [32, 3, 256, 256]         --
==========================================================================================
Total params: 1,828,131
Trainable params: 1,828,131
Non-trainable params: 0
Total mult-adds (G): 131.37
==========================================================================================
Input size (MB): 25.17
Forward/backward pass size (MB): 768.21
Params size (MB): 7.31
Estimated Total Size (MB): 800.69
==========================================================================================