|
@@ -1,218 +0,0 @@
|
|
|
-{
|
|
|
- "cells": [
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 1,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stderr",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "/home/kleinsteuber/anaconda3/envs/pytorch-gpu/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
|
- " from .autonotebook import tqdm as notebook_tqdm\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "from torch import nn\n",
|
|
|
- "from torchinfo import summary"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 5,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "class Autoencoder(nn.Module):\n",
|
|
|
- " def __init__(self, dropout=0.1, latent_features=512):\n",
|
|
|
- " super(Autoencoder, self).__init__()\n",
|
|
|
- " self.encoder = nn.Sequential(\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=2),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.Flatten(),\n",
|
|
|
- " nn.Linear(1024, latent_features),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- " )\n",
|
|
|
- " self.decoder = nn.Sequential(\n",
|
|
|
- " nn.Linear(512, 1024),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- " nn.Unflatten(1, (64, 4, 4)),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.ConvTranspose2d(64, 64, kernel_size=6, stride=2, padding=2),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.ConvTranspose2d(64, 64, kernel_size=8, stride=2, padding=3),\n",
|
|
|
- " nn.ReLU(True),\n",
|
|
|
- "\n",
|
|
|
- " nn.Dropout(dropout),\n",
|
|
|
- " nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=\"same\"),\n",
|
|
|
- " nn.Tanh(),\n",
|
|
|
- " )\n",
|
|
|
- " \n",
|
|
|
- " def forward(self, x):\n",
|
|
|
- " x = self.encoder(x)\n",
|
|
|
- " x = self.decoder(x)\n",
|
|
|
- " return x"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 7,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "==========================================================================================\n",
|
|
|
- "Layer (type:depth-idx) Output Shape Param #\n",
|
|
|
- "==========================================================================================\n",
|
|
|
- "Autoencoder [32, 3, 256, 256] --\n",
|
|
|
- "├─Sequential: 1-1 [32, 512] --\n",
|
|
|
- "│ └─Dropout: 2-1 [32, 3, 256, 256] --\n",
|
|
|
- "│ └─Conv2d: 2-2 [32, 32, 128, 128] 4,736\n",
|
|
|
- "│ └─ReLU: 2-3 [32, 32, 128, 128] --\n",
|
|
|
- "│ └─Dropout: 2-4 [32, 32, 128, 128] --\n",
|
|
|
- "│ └─Conv2d: 2-5 [32, 64, 64, 64] 51,264\n",
|
|
|
- "│ └─ReLU: 2-6 [32, 64, 64, 64] --\n",
|
|
|
- "│ └─Dropout: 2-7 [32, 64, 64, 64] --\n",
|
|
|
- "│ └─Conv2d: 2-8 [32, 64, 32, 32] 36,928\n",
|
|
|
- "│ └─ReLU: 2-9 [32, 64, 32, 32] --\n",
|
|
|
- "│ └─Dropout: 2-10 [32, 64, 32, 32] --\n",
|
|
|
- "│ └─Conv2d: 2-11 [32, 64, 16, 16] 36,928\n",
|
|
|
- "│ └─ReLU: 2-12 [32, 64, 16, 16] --\n",
|
|
|
- "│ └─Dropout: 2-13 [32, 64, 16, 16] --\n",
|
|
|
- "│ └─Conv2d: 2-14 [32, 128, 8, 8] 73,856\n",
|
|
|
- "│ └─ReLU: 2-15 [32, 128, 8, 8] --\n",
|
|
|
- "│ └─Dropout: 2-16 [32, 128, 8, 8] --\n",
|
|
|
- "│ └─Conv2d: 2-17 [32, 64, 4, 4] 73,792\n",
|
|
|
- "│ └─ReLU: 2-18 [32, 64, 4, 4] --\n",
|
|
|
- "│ └─Dropout: 2-19 [32, 64, 4, 4] --\n",
|
|
|
- "│ └─Flatten: 2-20 [32, 1024] --\n",
|
|
|
- "│ └─Linear: 2-21 [32, 512] 524,800\n",
|
|
|
- "│ └─ReLU: 2-22 [32, 512] --\n",
|
|
|
- "├─Sequential: 1-2 [32, 3, 256, 256] --\n",
|
|
|
- "│ └─Linear: 2-23 [32, 1024] 525,312\n",
|
|
|
- "│ └─ReLU: 2-24 [32, 1024] --\n",
|
|
|
- "│ └─Unflatten: 2-25 [32, 64, 4, 4] --\n",
|
|
|
- "│ └─Dropout: 2-26 [32, 64, 4, 4] --\n",
|
|
|
- "│ └─ConvTranspose2d: 2-27 [32, 128, 8, 8] 131,200\n",
|
|
|
- "│ └─ReLU: 2-28 [32, 128, 8, 8] --\n",
|
|
|
- "│ └─Dropout: 2-29 [32, 128, 8, 8] --\n",
|
|
|
- "│ └─ConvTranspose2d: 2-30 [32, 64, 16, 16] 131,136\n",
|
|
|
- "│ └─ReLU: 2-31 [32, 64, 16, 16] --\n",
|
|
|
- "│ └─Dropout: 2-32 [32, 64, 16, 16] --\n",
|
|
|
- "│ └─ConvTranspose2d: 2-33 [32, 64, 32, 32] 65,600\n",
|
|
|
- "│ └─ReLU: 2-34 [32, 64, 32, 32] --\n",
|
|
|
- "│ └─Dropout: 2-35 [32, 64, 32, 32] --\n",
|
|
|
- "│ └─ConvTranspose2d: 2-36 [32, 64, 64, 64] 65,600\n",
|
|
|
- "│ └─ReLU: 2-37 [32, 64, 64, 64] --\n",
|
|
|
- "│ └─Dropout: 2-38 [32, 64, 64, 64] --\n",
|
|
|
- "│ └─ConvTranspose2d: 2-39 [32, 32, 128, 128] 73,760\n",
|
|
|
- "│ └─ReLU: 2-40 [32, 32, 128, 128] --\n",
|
|
|
- "│ └─Dropout: 2-41 [32, 32, 128, 128] --\n",
|
|
|
- "│ └─ConvTranspose2d: 2-42 [32, 16, 256, 256] 32,784\n",
|
|
|
- "│ └─ReLU: 2-43 [32, 16, 256, 256] --\n",
|
|
|
- "│ └─Dropout: 2-44 [32, 16, 256, 256] --\n",
|
|
|
- "│ └─Conv2d: 2-45 [32, 3, 256, 256] 435\n",
|
|
|
- "│ └─Tanh: 2-46 [32, 3, 256, 256] --\n",
|
|
|
- "==========================================================================================\n",
|
|
|
- "Total params: 1,828,131\n",
|
|
|
- "Trainable params: 1,828,131\n",
|
|
|
- "Non-trainable params: 0\n",
|
|
|
- "Total mult-adds (G): 131.37\n",
|
|
|
- "==========================================================================================\n",
|
|
|
- "Input size (MB): 25.17\n",
|
|
|
- "Forward/backward pass size (MB): 768.21\n",
|
|
|
- "Params size (MB): 7.31\n",
|
|
|
- "Estimated Total Size (MB): 800.69\n",
|
|
|
- "=========================================================================================="
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 7,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "summary(Autoencoder(), (32, 3, 256, 256))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- }
|
|
|
- ],
|
|
|
- "metadata": {
|
|
|
- "kernelspec": {
|
|
|
- "display_name": "Python 3.10.4 ('pytorch-gpu')",
|
|
|
- "language": "python",
|
|
|
- "name": "python3"
|
|
|
- },
|
|
|
- "language_info": {
|
|
|
- "codemirror_mode": {
|
|
|
- "name": "ipython",
|
|
|
- "version": 3
|
|
|
- },
|
|
|
- "file_extension": ".py",
|
|
|
- "mimetype": "text/x-python",
|
|
|
- "name": "python",
|
|
|
- "nbconvert_exporter": "python",
|
|
|
- "pygments_lexer": "ipython3",
|
|
|
- "version": "3.10.4"
|
|
|
- },
|
|
|
- "orig_nbformat": 4,
|
|
|
- "vscode": {
|
|
|
- "interpreter": {
|
|
|
- "hash": "17cd5c528a3345b75540c61f907eece919c031d57a2ca1e5653325af249173c9"
|
|
|
- }
|
|
|
- }
|
|
|
- },
|
|
|
- "nbformat": 4,
|
|
|
- "nbformat_minor": 2
|
|
|
-}
|