autoencoder_experiments.ipynb 9.7 KB

from torch import nn
from torchinfo import summary
/home/kleinsteuber/anaconda3/envs/pytorch-gpu/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder applies six stride-2 convolutions (each halving the spatial
    size), flattens the result and projects it to ``latent_features`` values.
    The decoder mirrors this: a linear layer back to 64 x 4 x 4, six stride-2
    transposed convolutions (each doubling the spatial size) and a final
    3-channel convolution followed by ``tanh``.

    The flattened size of 1024 (= 64 * 4 * 4) is hard-coded, so the encoder
    only accepts inputs whose spatial size reduces to 4 x 4 after six
    halvings, e.g. 3 x 256 x 256 images.

    Args:
        dropout: Drop probability used by every ``nn.Dropout`` layer,
            including the one applied directly to the input.
        latent_features: Width of the bottleneck, i.e. the encoder output
            and the decoder input.
    """

    def __init__(self, dropout=0.1, latent_features=512):
        super().__init__()
        self.encoder = nn.Sequential(
            # Six downsampling stages: H x W -> H/64 x W/64.
            nn.Dropout(dropout),
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(inplace=True),

            nn.Dropout(dropout),
            nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),

            nn.Dropout(dropout),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),

            nn.Dropout(dropout),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),

            nn.Dropout(dropout),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),

            nn.Dropout(dropout),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),

            # Bottleneck: 64 * 4 * 4 = 1024 features -> latent vector.
            nn.Dropout(dropout),
            nn.Flatten(),
            nn.Linear(1024, latent_features),
            nn.ReLU(inplace=True),
        )
        self.decoder = nn.Sequential(
            # The input width must match the encoder's bottleneck width;
            # a fixed 512 here would break any non-default latent_features.
            nn.Linear(latent_features, 1024),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (64, 4, 4)),

            # Six upsampling stages; every kernel/padding pair below
            # exactly doubles the spatial size at stride 2.
            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),

            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),

            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),

            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),

            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=6, stride=2, padding=2),
            nn.ReLU(inplace=True),

            nn.Dropout(dropout),
            nn.ConvTranspose2d(64, 64, kernel_size=8, stride=2, padding=3),
            nn.ReLU(inplace=True),

            # Project back to 3 channels; tanh bounds the output to [-1, 1].
            nn.Dropout(dropout),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding="same"),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` to the latent vector and decode it back.

        Returns the output of the decoder applied to the encoder's output.
        """
        return self.decoder(self.encoder(x))
# Print a per-layer shape/parameter report for a default-configured model,
# traced with a batch of 32 RGB images of size 256 x 256.
# NOTE(review): the recorded output below lists 32/128/16-channel layers,
# which does not match the all-64-channel class above; it looks stale.
# Re-run this cell to refresh it.
summary(Autoencoder(), input_size=(32, 3, 256, 256))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Autoencoder                              [32, 3, 256, 256]         --
├─Sequential: 1-1                        [32, 512]                 --
│    └─Dropout: 2-1                      [32, 3, 256, 256]         --
│    └─Conv2d: 2-2                       [32, 32, 128, 128]        4,736
│    └─ReLU: 2-3                         [32, 32, 128, 128]        --
│    └─Dropout: 2-4                      [32, 32, 128, 128]        --
│    └─Conv2d: 2-5                       [32, 64, 64, 64]          51,264
│    └─ReLU: 2-6                         [32, 64, 64, 64]          --
│    └─Dropout: 2-7                      [32, 64, 64, 64]          --
│    └─Conv2d: 2-8                       [32, 64, 32, 32]          36,928
│    └─ReLU: 2-9                         [32, 64, 32, 32]          --
│    └─Dropout: 2-10                     [32, 64, 32, 32]          --
│    └─Conv2d: 2-11                      [32, 64, 16, 16]          36,928
│    └─ReLU: 2-12                        [32, 64, 16, 16]          --
│    └─Dropout: 2-13                     [32, 64, 16, 16]          --
│    └─Conv2d: 2-14                      [32, 128, 8, 8]           73,856
│    └─ReLU: 2-15                        [32, 128, 8, 8]           --
│    └─Dropout: 2-16                     [32, 128, 8, 8]           --
│    └─Conv2d: 2-17                      [32, 64, 4, 4]            73,792
│    └─ReLU: 2-18                        [32, 64, 4, 4]            --
│    └─Dropout: 2-19                     [32, 64, 4, 4]            --
│    └─Flatten: 2-20                     [32, 1024]                --
│    └─Linear: 2-21                      [32, 512]                 524,800
│    └─ReLU: 2-22                        [32, 512]                 --
├─Sequential: 1-2                        [32, 3, 256, 256]         --
│    └─Linear: 2-23                      [32, 1024]                525,312
│    └─ReLU: 2-24                        [32, 1024]                --
│    └─Unflatten: 2-25                   [32, 64, 4, 4]            --
│    └─Dropout: 2-26                     [32, 64, 4, 4]            --
│    └─ConvTranspose2d: 2-27             [32, 128, 8, 8]           131,200
│    └─ReLU: 2-28                        [32, 128, 8, 8]           --
│    └─Dropout: 2-29                     [32, 128, 8, 8]           --
│    └─ConvTranspose2d: 2-30             [32, 64, 16, 16]          131,136
│    └─ReLU: 2-31                        [32, 64, 16, 16]          --
│    └─Dropout: 2-32                     [32, 64, 16, 16]          --
│    └─ConvTranspose2d: 2-33             [32, 64, 32, 32]          65,600
│    └─ReLU: 2-34                        [32, 64, 32, 32]          --
│    └─Dropout: 2-35                     [32, 64, 32, 32]          --
│    └─ConvTranspose2d: 2-36             [32, 64, 64, 64]          65,600
│    └─ReLU: 2-37                        [32, 64, 64, 64]          --
│    └─Dropout: 2-38                     [32, 64, 64, 64]          --
│    └─ConvTranspose2d: 2-39             [32, 32, 128, 128]        73,760
│    └─ReLU: 2-40                        [32, 32, 128, 128]        --
│    └─Dropout: 2-41                     [32, 32, 128, 128]        --
│    └─ConvTranspose2d: 2-42             [32, 16, 256, 256]        32,784
│    └─ReLU: 2-43                        [32, 16, 256, 256]        --
│    └─Dropout: 2-44                     [32, 16, 256, 256]        --
│    └─Conv2d: 2-45                      [32, 3, 256, 256]         435
│    └─Tanh: 2-46                        [32, 3, 256, 256]         --
==========================================================================================
Total params: 1,828,131
Trainable params: 1,828,131
Non-trainable params: 0
Total mult-adds (G): 131.37
==========================================================================================
Input size (MB): 25.17
Forward/backward pass size (MB): 768.21
Params size (MB): 7.31
Estimated Total Size (MB): 800.69
==========================================================================================