{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/kleinsteuber/anaconda3/envs/pytorch-gpu/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from torch import nn\n", "from torchinfo import summary" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class Autoencoder(nn.Module):\n", " def __init__(self, dropout=0.1, latent_features=512):\n", " super(Autoencoder, self).__init__()\n", " self.encoder = nn.Sequential(\n", " nn.Dropout(dropout),\n", " nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=2),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.Flatten(),\n", " nn.Linear(1024, latent_features),\n", " nn.ReLU(True),\n", " )\n", " self.decoder = nn.Sequential(\n", " nn.Linear(512, 1024),\n", " nn.ReLU(True),\n", " nn.Unflatten(1, (64, 4, 4)),\n", "\n", " nn.Dropout(dropout),\n", " nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.ConvTranspose2d(64, 64, kernel_size=6, stride=2, padding=2),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.ConvTranspose2d(64, 64, kernel_size=8, stride=2, padding=3),\n", " nn.ReLU(True),\n", "\n", " nn.Dropout(dropout),\n", " nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=\"same\"),\n", " nn.Tanh(),\n", " )\n", " \n", " def forward(self, x):\n", " x = self.encoder(x)\n", " x = self.decoder(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", "Autoencoder [32, 3, 256, 256] --\n", "├─Sequential: 1-1 [32, 512] --\n", "│ └─Dropout: 2-1 [32, 3, 256, 256] --\n", "│ └─Conv2d: 2-2 [32, 32, 128, 128] 4,736\n", "│ └─ReLU: 2-3 [32, 32, 128, 128] --\n", "│ └─Dropout: 2-4 [32, 32, 128, 128] --\n", "│ └─Conv2d: 2-5 [32, 64, 64, 64] 51,264\n", "│ └─ReLU: 2-6 [32, 64, 64, 64] --\n", "│ └─Dropout: 2-7 [32, 64, 64, 64] --\n", "│ └─Conv2d: 2-8 [32, 64, 32, 32] 36,928\n", "│ └─ReLU: 2-9 [32, 64, 32, 32] --\n", "│ └─Dropout: 2-10 [32, 64, 32, 32] --\n", "│ └─Conv2d: 2-11 [32, 64, 16, 16] 36,928\n", "│ └─ReLU: 2-12 [32, 64, 16, 16] --\n", "│ └─Dropout: 2-13 [32, 64, 16, 16] --\n", "│ └─Conv2d: 2-14 [32, 128, 8, 8] 73,856\n", "│ └─ReLU: 2-15 [32, 128, 8, 8] --\n", "│ └─Dropout: 2-16 [32, 128, 8, 8] --\n", "│ └─Conv2d: 2-17 [32, 64, 4, 4] 73,792\n", "│ └─ReLU: 2-18 [32, 64, 4, 4] --\n", "│ └─Dropout: 2-19 [32, 64, 4, 4] --\n", "│ └─Flatten: 2-20 [32, 1024] --\n", "│ └─Linear: 2-21 [32, 512] 524,800\n", "│ └─ReLU: 2-22 [32, 512] --\n", "├─Sequential: 1-2 [32, 3, 256, 256] --\n", "│ └─Linear: 2-23 [32, 1024] 525,312\n", "│ └─ReLU: 2-24 [32, 1024] --\n", "│ └─Unflatten: 2-25 [32, 64, 4, 4] --\n", "│ └─Dropout: 2-26 [32, 64, 4, 4] --\n", "│ └─ConvTranspose2d: 2-27 [32, 128, 8, 8] 131,200\n", "│ └─ReLU: 2-28 [32, 128, 8, 8] --\n", "│ └─Dropout: 2-29 [32, 128, 8, 8] --\n", "│ └─ConvTranspose2d: 2-30 [32, 64, 16, 16] 131,136\n", "│ └─ReLU: 2-31 [32, 64, 16, 16] --\n", "│ └─Dropout: 2-32 [32, 64, 16, 16] --\n", "│ └─ConvTranspose2d: 2-33 [32, 64, 32, 32] 65,600\n", "│ └─ReLU: 2-34 [32, 64, 32, 32] --\n", "│ └─Dropout: 2-35 [32, 64, 32, 32] --\n", "│ └─ConvTranspose2d: 2-36 [32, 64, 64, 64] 65,600\n", "│ └─ReLU: 2-37 [32, 64, 64, 64] --\n", "│ └─Dropout: 2-38 [32, 64, 64, 64] --\n", "│ └─ConvTranspose2d: 2-39 [32, 32, 128, 128] 73,760\n", "│ └─ReLU: 2-40 [32, 32, 128, 128] --\n", "│ └─Dropout: 2-41 [32, 32, 128, 128] --\n", "│ └─ConvTranspose2d: 2-42 [32, 16, 256, 256] 32,784\n", "│ └─ReLU: 2-43 [32, 16, 256, 256] --\n", "│ └─Dropout: 2-44 [32, 16, 256, 256] --\n", "│ └─Conv2d: 2-45 [32, 3, 256, 256] 435\n", "│ └─Tanh: 2-46 [32, 3, 256, 256] --\n", "==========================================================================================\n", "Total params: 1,828,131\n", "Trainable params: 1,828,131\n", "Non-trainable params: 0\n", "Total mult-adds (G): 131.37\n", "==========================================================================================\n", "Input size (MB): 25.17\n", "Forward/backward pass size (MB): 768.21\n", "Params size (MB): 7.31\n", "Estimated Total Size (MB): 800.69\n", "==========================================================================================" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "summary(Autoencoder(), (32, 3, 256, 256))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10.4 ('pytorch-gpu')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "17cd5c528a3345b75540c61f907eece919c031d57a2ca1e5653325af249173c9" } } }, "nbformat": 4, "nbformat_minor": 2 }