{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Convolution / Pooling レイヤの実装" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4次元配列" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys, os\n", "sys.path.append(os.path.abspath(os.path.join('..', 'sample')))\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(10, 1, 28, 28)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = np.random.rand(10, 1, 28, 28) # ランダムにデータを生成\n", "x.shape" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, 28, 28)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x[0].shape # 1つ目のデータにアクセス" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, 28, 28)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x[1].shape # 2つ目のデータにアクセス" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[7.46455120e-01, 2.27963876e-01, 8.27426402e-01, 8.26803413e-01,\n", " 8.60531394e-01, 5.11418231e-01, 5.27047620e-01, 7.03399015e-01,\n", " 7.32704811e-02, 1.09308993e-01, 8.15014474e-01, 3.14114566e-01,\n", " 9.24025136e-01, 2.80541838e-01, 6.79584986e-03, 1.10792427e-01,\n", " 4.48246393e-01, 8.82760812e-02, 9.05718234e-01, 7.86133933e-01,\n", " 2.55209058e-01, 5.22796916e-01, 7.02683868e-01, 1.94566704e-01,\n", " 6.36620563e-01, 5.16364544e-01, 3.03474860e-01, 2.11320952e-01],\n", " [4.03951927e-01, 1.87770854e-01, 5.80080664e-01, 8.34256457e-01,\n", " 9.39314094e-01, 2.45657250e-01, 8.65530830e-01, 9.20854345e-01,\n", " 1.56331998e-01, 7.62897476e-01, 8.67818851e-01, 3.70141368e-01,\n", " 2.84505630e-01, 3.78607312e-01, 
5.58315978e-01, 9.20036356e-01,\n", " 2.58880224e-01, 9.56287150e-01, 2.20749950e-01, 6.86061026e-01,\n", " 1.04820287e-01, 4.47132139e-01, 7.62985191e-01, 8.12218572e-01,\n", " 3.59674603e-01, 2.05085311e-01, 8.31086229e-01, 9.64650123e-01],\n", " [5.28956298e-01, 5.52748122e-01, 4.79920791e-01, 2.81148690e-01,\n", " 4.35900769e-01, 6.95021156e-01, 4.86216713e-01, 1.22245813e-01,\n", " 8.62867648e-01, 4.40538693e-01, 9.72116112e-01, 3.63611287e-02,\n", " 9.14045729e-01, 5.19092159e-01, 9.72350122e-01, 7.63945661e-01,\n", " 3.55482099e-01, 6.53232699e-01, 8.65791529e-01, 3.22821152e-01,\n", " 2.62456741e-01, 6.63374157e-01, 4.10376090e-01, 4.24009667e-01,\n", " 9.87274223e-01, 4.62407377e-02, 6.56456496e-01, 3.13730495e-01],\n", " [7.18606295e-01, 9.10762806e-01, 5.00487368e-02, 8.09228955e-01,\n", " 5.67282967e-01, 9.33370771e-02, 1.98947330e-01, 1.62631070e-01,\n", " 6.00090613e-01, 2.49121534e-01, 1.14705218e-01, 5.26247066e-01,\n", " 4.25716715e-01, 3.90532401e-01, 7.46411969e-01, 5.01446742e-01,\n", " 5.86019162e-01, 3.41213168e-01, 9.69051708e-01, 3.29134806e-02,\n", " 7.34018633e-01, 1.93056492e-01, 1.93553843e-01, 2.23852431e-01,\n", " 6.96258639e-01, 4.11262792e-01, 2.28029398e-01, 5.22302539e-01],\n", " [4.23475670e-01, 1.65699557e-01, 2.96014552e-01, 1.94950327e-01,\n", " 2.35840055e-01, 8.02136570e-01, 6.36881297e-01, 5.18361323e-01,\n", " 7.30289280e-01, 9.34376772e-01, 6.46827533e-01, 4.38049502e-01,\n", " 6.29735765e-01, 2.87765802e-01, 9.15467672e-01, 6.16203016e-01,\n", " 6.13656954e-01, 9.95022410e-01, 3.79982610e-01, 5.71789728e-01,\n", " 5.86652106e-01, 1.46933247e-01, 2.31418917e-01, 4.79836605e-01,\n", " 7.57234727e-01, 9.93779261e-02, 2.31958748e-01, 7.50149019e-01],\n", " [6.82747995e-01, 2.30982281e-01, 4.16005629e-02, 5.88359095e-01,\n", " 9.33558150e-01, 7.17475803e-01, 1.24411547e-01, 6.99189586e-01,\n", " 5.88954457e-01, 8.13507864e-01, 6.48615165e-01, 2.12180529e-01,\n", " 6.31167147e-01, 1.55236433e-01, 9.29129782e-01, 
4.68010480e-01,\n", " 6.13947997e-01, 7.93487018e-01, 7.86107201e-01, 9.50296398e-01,\n", " 4.68254118e-01, 5.67975773e-01, 2.75214473e-01, 2.27449385e-01,\n", " 5.01343857e-01, 2.16418773e-01, 7.74640074e-01, 2.07921267e-01],\n", " [3.58466076e-01, 9.13461301e-01, 7.59217891e-02, 3.64101937e-01,\n", " 2.70674137e-01, 7.04886786e-01, 4.85662499e-01, 5.37670004e-01,\n", " 1.53445784e-01, 6.09000782e-01, 5.69970931e-01, 9.59354251e-01,\n", " 9.22195823e-01, 1.14199247e-01, 8.26350814e-01, 6.61703847e-01,\n", " 1.60396802e-01, 4.06424995e-01, 9.54833517e-01, 6.03657240e-01,\n", " 2.33391729e-01, 4.01521639e-01, 3.22540182e-01, 2.30126023e-01,\n", " 2.61640163e-01, 6.32554921e-01, 4.41241830e-01, 2.59691392e-01],\n", " [8.33443576e-01, 1.37501556e-01, 6.97651779e-01, 3.54500895e-01,\n", " 5.84440652e-01, 9.35578521e-01, 7.18540080e-01, 9.15061370e-01,\n", " 9.56877762e-01, 2.95650963e-01, 2.99195464e-01, 1.26332294e-01,\n", " 3.52259685e-01, 4.19420867e-01, 4.42227497e-01, 2.92890827e-01,\n", " 6.73741634e-01, 4.15804942e-01, 7.92465504e-01, 9.71236774e-01,\n", " 3.57554181e-01, 6.94313630e-01, 3.26568391e-01, 6.46365892e-01,\n", " 7.07729163e-01, 7.09206960e-01, 7.75490307e-01, 4.21483438e-01],\n", " [8.14915163e-03, 5.17055862e-01, 4.39072422e-01, 8.49363548e-01,\n", " 3.14305426e-02, 2.32023646e-01, 9.61729849e-01, 1.65631730e-01,\n", " 7.63607241e-01, 7.40085716e-02, 3.43271602e-02, 4.43440555e-01,\n", " 9.41311324e-01, 9.98836578e-01, 3.44661035e-01, 7.14372472e-01,\n", " 8.98675041e-02, 5.76569855e-01, 8.54737696e-01, 7.49008869e-01,\n", " 1.56247113e-01, 6.55426268e-01, 6.24778850e-01, 6.10067145e-01,\n", " 4.27663835e-01, 1.37521906e-01, 4.88969758e-01, 8.02401202e-01],\n", " [8.86015600e-01, 5.01397989e-01, 5.90601229e-02, 8.65306596e-01,\n", " 4.27583666e-01, 1.51195036e-01, 4.71110749e-02, 9.85532894e-01,\n", " 6.14647733e-01, 6.32465831e-01, 3.28685744e-01, 2.93326353e-01,\n", " 6.51783237e-01, 4.87892684e-01, 2.14292299e-01, 6.07874852e-01,\n", " 
9.31987762e-01, 7.43693241e-01, 7.28212503e-01, 6.50320583e-01,\n", " 6.47876303e-01, 4.00688742e-02, 3.14540384e-01, 9.96390996e-01,\n", " 4.60466245e-01, 1.78761983e-01, 4.22862838e-02, 6.19093258e-01],\n", " [9.51484168e-01, 6.87546923e-01, 3.61576240e-01, 9.04298962e-01,\n", " 4.69256161e-01, 9.04771944e-01, 1.94895147e-01, 4.06794905e-01,\n", " 1.99110372e-01, 4.31780007e-01, 4.16901194e-01, 7.98212619e-02,\n", " 7.19721372e-01, 7.49251099e-01, 8.45644922e-01, 1.41772809e-01,\n", " 1.05602287e-01, 6.69576440e-02, 6.28086657e-01, 6.16504993e-01,\n", " 8.61918093e-01, 2.35597159e-01, 6.92240310e-01, 2.42599478e-01,\n", " 5.77458851e-01, 6.39502674e-01, 1.00905682e-01, 9.13104018e-01],\n", " [4.84151537e-01, 1.47566455e-01, 7.23894307e-01, 8.32892863e-01,\n", " 9.45434932e-01, 8.32320799e-01, 6.76105737e-01, 5.79858742e-01,\n", " 1.19561335e-01, 2.70875589e-02, 1.81584326e-01, 3.86005149e-01,\n", " 9.11250426e-01, 8.14528602e-01, 6.52829314e-01, 2.95412580e-01,\n", " 6.83016162e-01, 6.16147171e-01, 4.14563219e-04, 6.15367620e-01,\n", " 8.23406599e-02, 1.92954798e-01, 9.68460770e-01, 4.75632089e-01,\n", " 2.72164760e-01, 9.37508463e-01, 8.33939601e-01, 5.71895402e-01],\n", " [8.24197041e-01, 5.06285575e-01, 8.66060624e-01, 4.45276569e-03,\n", " 7.85483618e-02, 9.64055843e-01, 8.23314108e-01, 9.07433479e-01,\n", " 6.38287109e-01, 8.16646323e-01, 6.83621283e-01, 1.35911722e-02,\n", " 1.14781164e-01, 1.04495873e-01, 3.64866295e-02, 4.14716225e-01,\n", " 5.60912032e-02, 6.47171268e-01, 8.40610775e-01, 9.62700512e-01,\n", " 7.06501458e-02, 8.25360600e-01, 7.62349703e-01, 8.26256597e-01,\n", " 4.43647025e-01, 3.25188836e-01, 7.29252270e-01, 8.88779579e-01],\n", " [5.72501121e-01, 6.87171452e-01, 6.90831498e-01, 3.55823391e-01,\n", " 8.28055797e-01, 3.69172604e-01, 2.74057542e-01, 9.18860834e-02,\n", " 2.16264888e-01, 9.52469910e-02, 6.66412171e-01, 4.93963517e-01,\n", " 7.19060568e-01, 9.70385890e-01, 8.55261763e-02, 7.17406599e-02,\n", " 2.68435031e-01, 7.61333140e-01, 
4.25854845e-01, 3.97499868e-01,\n", " 7.52547205e-01, 2.41306161e-01, 2.53409801e-01, 4.03980564e-01,\n", " 3.84571924e-01, 5.91608123e-01, 4.62163291e-01, 7.49126877e-01],\n", " [9.25316754e-01, 9.23039112e-01, 8.71130809e-02, 9.36325716e-01,\n", " 9.95759487e-01, 1.19317067e-01, 3.16613508e-01, 7.78033316e-01,\n", " 4.00319805e-01, 4.45151894e-01, 9.31948346e-01, 6.02004659e-01,\n", " 4.89465719e-01, 8.55296974e-01, 6.82682328e-02, 6.69848190e-01,\n", " 6.97620674e-02, 2.26382602e-01, 6.48364301e-01, 3.12417911e-02,\n", " 2.07617937e-01, 6.58064313e-01, 9.91172575e-01, 5.77215712e-01,\n", " 1.44586353e-01, 4.06569044e-01, 5.82562135e-01, 5.18761125e-01],\n", " [1.39701323e-01, 5.33849445e-01, 1.78333341e-02, 7.36128529e-01,\n", " 6.21878345e-01, 3.52825186e-01, 5.36497500e-03, 1.26344083e-01,\n", " 4.45264773e-01, 4.71438644e-01, 3.32636427e-01, 3.54839822e-01,\n", " 9.37085984e-01, 1.71536560e-01, 2.13848610e-02, 5.55817934e-01,\n", " 7.90145246e-01, 1.31151994e-01, 2.98448043e-01, 4.94604124e-01,\n", " 9.43620343e-01, 5.35602472e-02, 3.16143949e-01, 1.74964831e-01,\n", " 7.19499720e-01, 1.31382913e-01, 7.60031899e-02, 5.66504602e-01],\n", " [1.56467006e-01, 4.71481487e-01, 4.13741235e-01, 2.41054872e-01,\n", " 3.87437537e-01, 6.51896773e-02, 6.07919911e-01, 4.66169349e-01,\n", " 6.50662995e-01, 9.81575207e-03, 3.19298200e-01, 9.67503877e-01,\n", " 4.52597496e-03, 1.65415847e-01, 7.67148083e-01, 1.95315427e-01,\n", " 4.26667194e-01, 6.61384978e-01, 2.77242372e-01, 5.95361187e-01,\n", " 8.10458079e-01, 1.67760315e-01, 3.78285960e-01, 8.35850786e-02,\n", " 7.40909903e-01, 9.07375069e-01, 1.44170694e-01, 5.15366144e-01],\n", " [7.49555829e-03, 5.11959771e-01, 2.85293564e-01, 6.64528004e-01,\n", " 8.62378006e-01, 4.71103873e-02, 4.54527113e-01, 7.96920477e-01,\n", " 5.10112624e-01, 9.25747312e-01, 7.38245639e-01, 1.29147742e-01,\n", " 5.66681941e-01, 2.59535327e-01, 4.00606474e-01, 4.55794248e-01,\n", " 6.12936597e-01, 7.80401311e-01, 9.34560912e-01, 
5.03399736e-01,\n", " 5.96915325e-01, 1.41315992e-02, 2.79090593e-01, 3.37042978e-01,\n", " 9.20900064e-01, 8.74347084e-03, 6.94411881e-01, 1.59258232e-02],\n", " [6.41399526e-01, 3.62281841e-01, 4.60624445e-01, 3.46285018e-01,\n", " 9.99380558e-01, 8.61930205e-01, 8.53832245e-01, 5.57079181e-01,\n", " 3.47734927e-01, 1.13003939e-01, 6.80945107e-01, 4.60689749e-01,\n", " 1.62382832e-01, 9.31673595e-01, 2.48753653e-01, 6.52301224e-01,\n", " 7.35557585e-01, 7.13339367e-01, 8.41276782e-01, 6.06089795e-01,\n", " 5.38195803e-01, 8.01575372e-02, 1.77487439e-01, 8.23636477e-01,\n", " 5.57477258e-01, 3.88232770e-02, 9.62549655e-01, 1.92759082e-01],\n", " [8.95211127e-01, 5.85759221e-01, 7.41528444e-02, 7.95912051e-01,\n", " 5.28049254e-01, 1.12707108e-01, 1.29677564e-01, 1.53757394e-01,\n", " 6.70237211e-01, 6.37532201e-01, 1.17483989e-01, 8.12396578e-01,\n", " 6.66003608e-01, 9.98811030e-02, 9.84756276e-01, 1.60092791e-01,\n", " 6.32020017e-01, 6.02422379e-01, 5.27858455e-02, 3.73056931e-01,\n", " 6.76505665e-01, 5.70869394e-01, 5.39477126e-01, 4.65371437e-01,\n", " 3.10655539e-01, 7.42982786e-03, 9.40493771e-01, 8.41843487e-01],\n", " [5.51324343e-01, 9.71110641e-01, 5.90230831e-01, 8.13148645e-01,\n", " 4.27976073e-01, 5.56336557e-01, 1.14200120e-01, 9.06750620e-01,\n", " 6.06534111e-01, 6.29953377e-01, 7.72502297e-01, 7.79928307e-01,\n", " 3.31396467e-01, 6.16598091e-01, 5.74269671e-01, 9.21892328e-01,\n", " 4.81680161e-01, 9.63003253e-01, 8.61772631e-02, 4.24563847e-01,\n", " 2.65460961e-01, 8.81578303e-01, 3.69491498e-01, 7.90004429e-01,\n", " 7.47649224e-01, 6.27816654e-01, 8.56111462e-01, 7.33647893e-01],\n", " [7.93071000e-01, 9.39395190e-01, 8.06050527e-01, 2.46194564e-01,\n", " 1.61274892e-01, 9.52770249e-01, 6.79299926e-01, 8.16650765e-01,\n", " 1.69010824e-01, 8.81568644e-01, 4.74255451e-01, 9.14069211e-01,\n", " 2.14361846e-01, 3.71657844e-01, 3.88492542e-01, 3.27328445e-01,\n", " 1.66717286e-01, 9.34844246e-02, 6.18202260e-01, 5.30945134e-01,\n", " 
3.91600230e-01, 5.43912342e-01, 4.61725734e-01, 1.66011244e-01,\n", " 4.77644354e-01, 9.99968315e-01, 2.92467773e-01, 7.06248658e-01],\n", " [3.55233392e-01, 6.61529015e-01, 3.45724792e-01, 7.35879313e-01,\n", " 3.13486359e-01, 8.89689578e-01, 2.88355462e-01, 4.67812371e-01,\n", " 3.42406207e-01, 1.54923541e-01, 1.71601436e-02, 9.06808621e-01,\n", " 6.26267702e-01, 8.46050504e-01, 4.97191957e-01, 4.42864747e-01,\n", " 5.51167422e-01, 7.37057800e-01, 6.47698534e-01, 6.78910403e-02,\n", " 1.92058568e-01, 3.89085151e-01, 9.34290783e-01, 7.08902162e-01,\n", " 5.64528391e-01, 5.23889856e-01, 1.03287484e-01, 5.64184354e-01],\n", " [2.05301385e-01, 1.33711294e-01, 4.50556811e-01, 3.80711564e-01,\n", " 4.40834009e-01, 8.15779833e-01, 3.91492425e-01, 2.27680537e-01,\n", " 5.19823039e-01, 5.19845330e-02, 8.16212104e-01, 7.64287820e-01,\n", " 7.54757569e-01, 8.65711656e-01, 3.69304766e-01, 6.49056286e-01,\n", " 9.57079183e-02, 4.45317716e-02, 7.92986942e-01, 7.65407464e-01,\n", " 5.53620957e-01, 5.97657478e-01, 7.53313072e-01, 1.31745145e-01,\n", " 6.11002149e-01, 8.21139387e-01, 5.20170600e-01, 1.92531292e-01],\n", " [2.48764504e-01, 8.57158462e-01, 1.60636772e-01, 9.36213070e-01,\n", " 4.85975137e-01, 6.43898513e-01, 5.08286137e-02, 3.04333012e-02,\n", " 2.03940063e-01, 8.70897901e-01, 4.81609409e-01, 6.15979279e-01,\n", " 8.61882806e-01, 3.43366836e-01, 4.39118719e-01, 7.19994211e-01,\n", " 5.23686672e-01, 5.32908196e-01, 2.29339751e-01, 5.80869441e-01,\n", " 4.46953617e-02, 4.21767573e-01, 8.26056443e-01, 1.12113206e-01,\n", " 9.70801973e-01, 9.28865112e-01, 4.90708466e-01, 2.16601219e-01],\n", " [9.16337738e-01, 8.96942002e-01, 2.90703863e-01, 1.85660447e-01,\n", " 2.26900348e-01, 4.13983007e-01, 2.69570598e-01, 2.88570991e-01,\n", " 2.78554594e-01, 4.19120959e-01, 2.54393111e-01, 7.02132445e-01,\n", " 7.35633139e-01, 7.43792127e-01, 2.97035809e-01, 3.37200462e-01,\n", " 4.37576240e-01, 3.68025116e-01, 1.99905348e-01, 1.73117211e-01,\n", " 2.40863480e-01, 2.62601807e-02, 
8.53133932e-01, 5.22292208e-01,\n", " 6.42894472e-01, 4.29176565e-01, 4.14385911e-01, 3.05540051e-01],\n", " [9.73592820e-01, 9.15234502e-01, 1.64450167e-01, 5.58369077e-03,\n", " 7.16045055e-02, 7.42231006e-01, 3.66223842e-01, 7.31008433e-01,\n", " 4.85227212e-01, 5.07172860e-02, 5.26328427e-01, 2.14019661e-01,\n", " 3.56974984e-01, 8.06161037e-01, 7.08070836e-01, 7.50775252e-01,\n", " 2.64836881e-01, 9.84340330e-01, 5.42772514e-01, 4.93113767e-01,\n", " 5.20820336e-01, 5.30817789e-01, 5.62515927e-02, 4.76085351e-01,\n", " 8.25129849e-01, 6.98940576e-01, 5.21246816e-01, 3.22277406e-01],\n", " [5.44108028e-01, 1.01284146e-01, 9.22323594e-01, 2.24490182e-01,\n", " 2.51889361e-01, 8.27815687e-01, 3.32006987e-01, 6.28728377e-02,\n", " 5.09628768e-02, 2.79569256e-01, 7.56242927e-01, 4.44651964e-01,\n", " 6.85801870e-02, 8.05050096e-01, 2.43117627e-02, 4.94562512e-01,\n", " 9.57729021e-01, 6.47568055e-02, 3.96393700e-01, 2.90967161e-02,\n", " 3.48168205e-01, 6.79447073e-01, 4.05632771e-03, 4.20028183e-01,\n", " 3.02600544e-01, 3.55424447e-01, 5.13069265e-01, 9.94934160e-01]])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x[0, 0] # 1つ目のデータの、1チャンネル目の空間データにアクセス" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## im2colによる展開\n", "- NumPyでは、`for`文を使うと処理が遅くなってしまう\n", "- `for`文を使わず、`im2col`で入力データを2次元の行列に展開し、行列積でまとめて計算する\n", "\n", "`im2col`は image to columnの略、「画像から行列へ」" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Convolution レイヤの実装" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "col1.shape: (9, 75)\n", "col2.shape: (90, 75)\n" ] } ], "source": [
"# im2col unrolls inputs for the forward passes; col2im is needed by the backward passes below\n",
"from common.util import im2col, col2im\n",
"\n",
"# Batch size 1: one 3-channel 7x7 input\n",
"x1 = np.random.rand(1, 3, 7, 7)\n",
"col1 = im2col(x1, 5, 5, stride=1, pad=0)\n",
"print(f'col1.shape: {col1.shape}')\n",
"\n",
"# Batch size 10: ten 3-channel 7x7 inputs\n",
"x2 = np.random.rand(10, 3, 7, 7)\n",
"col2 = im2col(x2, 5, 5, stride=1, pad=0)\n",
"print(f'col2.shape: {col2.shape}')" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
"class Convolution:\n",
"    \"\"\"\n",
"    Convolution layer implemented with im2col / col2im.\n",
"\n",
"    W      : filter weights, 4-D, unpacked as (FN, C, FH, FW)\n",
"    b      : bias added to every output position; it must broadcast against\n",
"             the (N*out_h*out_w, FN) product -- presumably shape (FN,)\n",
"    stride : step between filter positions\n",
"    pad    : padding width passed to im2col / col2im\n",
"    \"\"\"\n",
"    def __init__(self, W, b, stride=1, pad=0):\n",
"        self.W = W\n",
"        self.b = b\n",
"        self.stride = stride\n",
"        self.pad = pad\n",
"\n",
"        # Intermediate data cached by forward() for use in backward()\n",
"        self.x = None\n",
"        self.col = None\n",
"        self.col_W = None\n",
"\n",
"        # Gradients of the weights and bias, filled in by backward()\n",
"        self.dW = None\n",
"        self.db = None\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"\n",
"        Forward pass.\n",
"\n",
"        x : input of shape (N, C, H, W)\n",
"        Returns the output of shape (N, FN, out_h, out_w).\n",
"        \"\"\"\n",
"        FN, C, FH, FW = self.W.shape\n",
"        N, C, H, W = x.shape\n",
"        # Number of filter positions along each axis (height / width of the output)\n",
"        out_h = 1 + (H + 2*self.pad - FH) // self.stride\n",
"        out_w = 1 + (W + 2*self.pad - FW) // self.stride\n",
"\n",
"        col = im2col(x, FH, FW, self.stride, self.pad)  # unrolled input: (N*out_h*out_w, C*FH*FW)\n",
"        col_W = self.W.reshape(FN, -1).T                 # unrolled filters: (C*FH*FW, FN)\n",
"        out = np.dot(col, col_W) + self.b                # the whole convolution as one matrix product\n",
"\n",
"        # (N*out_h*out_w, FN) -> (N, out_h, out_w, FN) -> (N, FN, out_h, out_w)\n",
"        out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)\n",
"\n",
"        self.x = x\n",
"        self.col = col\n",
"        self.col_W = col_W\n",
"\n",
"        return out\n",
"\n",
"    def backward(self, dout):\n",
"        \"\"\"\n",
"        Backward pass.\n",
"\n",
"        dout : upstream gradient of shape (N, FN, out_h, out_w)\n",
"        Sets self.dW (reshaped to (FN, C, FH, FW)) and self.db, and returns dx,\n",
"        the gradient w.r.t. the forward input (col2im is given self.x.shape).\n",
"        \"\"\"\n",
"        FN, C, FH, FW = self.W.shape\n",
"        # Back to the row layout of forward's product: (N*out_h*out_w, FN)\n",
"        dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)\n",
"\n",
"        self.db = np.sum(dout, axis=0)\n",
"        self.dW = np.dot(self.col.T, dout)\n",
"        self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)\n",
"\n",
"        dcol = np.dot(dout, self.col_W.T)\n",
"        dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)\n",
"\n",
"        return dx\n",
"\n",
"\n",
"# Backward-compatible alias for the original (misspelled) class name\n",
"Comvolution = Convolution" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Poolingレイヤの実装" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
"class Pooling:\n",
"    \"\"\"\n",
"    Max-pooling layer implemented with im2col / col2im.\n",
"\n",
"    pool_h, pool_w : height and width of the pooling window\n",
"    stride         : step between window positions\n",
"    pad            : padding width passed to im2col / col2im\n",
"    \"\"\"\n",
"    def __init__(self, pool_h, pool_w, stride=1, pad=0):\n",
"        self.pool_h = pool_h\n",
"        self.pool_w = pool_w\n",
"        self.stride = stride\n",
"        self.pad = pad\n",
"\n",
"        # Cached by forward() for use in backward()\n",
"        self.x = None\n",
"        self.arg_max = None\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"\n",
"        Forward pass: maximum of every window, independently per channel.\n",
"\n",
"        x : input of shape (N, C, H, W)\n",
"        Returns the output of shape (N, C, out_h, out_w).\n",
"        \"\"\"\n",
"        N, C, H, W = x.shape\n",
"        # im2col is called with self.pad, so the padded size decides the number\n",
"        # of windows (as in Convolution.forward); leaving 2*pad out made the\n",
"        # reshape below fail whenever pad > 0.\n",
"        out_h = 1 + (H + 2*self.pad - self.pool_h) // self.stride\n",
"        out_w = 1 + (W + 2*self.pad - self.pool_w) // self.stride\n",
"\n",
"        # Unroll the input, then give each (window, channel) pair its own row\n",
"        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)\n",
"        col = col.reshape(-1, self.pool_h*self.pool_w)\n",
"\n",
"        # Row-wise maximum; its position is kept for backward()\n",
"        # NOTE(review): padded cells take part in the max -- if im2col pads with\n",
"        # zeros (to confirm), pad > 0 with negative inputs can yield a padding 0.\n",
"        arg_max = np.argmax(col, axis=1)\n",
"        out = np.max(col, axis=1)\n",
"\n",
"        # (N*out_h*out_w*C,) -> (N, out_h, out_w, C) -> (N, C, out_h, out_w)\n",
"        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)\n",
"\n",
"        self.x = x\n",
"        self.arg_max = arg_max\n",
"\n",
"        return out\n",
"\n",
"    def backward(self, dout):\n",
"        \"\"\"\n",
"        Backward pass: each upstream gradient is routed only to the argmax\n",
"        position of its window (dmax is zero everywhere else).\n",
"\n",
"        dout : upstream gradient of shape (N, C, out_h, out_w)\n",
"        Returns dx, the gradient w.r.t. the forward input (col2im is given self.x.shape).\n",
"        \"\"\"\n",
"        dout = dout.transpose(0, 2, 3, 1)  # (N, out_h, out_w, C): same order as forward's rows\n",
"\n",
"        pool_size = self.pool_h * self.pool_w\n",
"        dmax = np.zeros((dout.size, pool_size))\n",
"        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()\n",
"        dmax = dmax.reshape(dout.shape + (pool_size,))  # (N, out_h, out_w, C, pool_size)\n",
"\n",
"        # Back to im2col's layout: one row per window position, (N*out_h*out_w, C*pool_size)\n",
"        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)\n",
"        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)\n",
"\n",
"        return dx" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Pooling.backward の配列形状の確認\n", "`Pooling` を継承し、`backward` の途中で作られる配列の形状を表示する" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x.shape: (1, 2, 16, 32)\n", "\n", "out.shape: (1, 2, 8, 16)\n", "\n", "----reshape.1----\n", "dmax.shape: (256, 4)\n", "dout.shape: (1, 8, 16, 2)\n", "(pool_size,): (4,)\n", "dmax.shape: (1, 8, 16, 2, 4)\n", "----reshape.2----\n", "dmax.shape[0]: 1\n", "dmax.shape[1]: 8\n", "dmax.shape[2]: 16\n", "dcol.shape: (128, 8)\n", "dx.shape: (1, 2, 16, 32)\n" ] } ], "source": [
"class VerbosePooling(Pooling):\n",
"    \"\"\"\n",
"    Pooling that prints the intermediate array shapes of backward().\n",
"\n",
"    Subclassing (instead of redefining Pooling) keeps the computation in one\n",
"    place: this class only reports the shapes and then delegates to Pooling.\n",
"    \"\"\"\n",
"    def backward(self, dout):\n",
"        dout_nhwc = dout.transpose(0, 2, 3, 1)\n",
"        pool_size = self.pool_h * self.pool_w\n",
"        dmax_shape = dout_nhwc.shape + (pool_size,)\n",
"        dcol_rows = dmax_shape[0] * dmax_shape[1] * dmax_shape[2]\n",
"\n",
"        print(\"----reshape.1----\")\n",
"        print(f\"dmax.shape: {(dout_nhwc.size, pool_size)}\")\n",
"        print(f\"dout.shape: {dout_nhwc.shape}\")\n",
"        print(f\"(pool_size,): {(pool_size,)}\")\n",
"        print(f\"dmax.shape: {dmax_shape}\")\n",
"\n",
"        print(\"----reshape.2----\")\n",
"        print(f\"dmax.shape[0]: {dmax_shape[0]}\")\n",
"        print(f\"dmax.shape[1]: {dmax_shape[1]}\")\n",
"        print(f\"dmax.shape[2]: {dmax_shape[2]}\")\n",
"        print(f\"dcol.shape: {(dcol_rows, dout_nhwc.size * pool_size // dcol_rows)}\")\n",
"\n",
"        return super().backward(dout)\n",
"\n",
"\n",
"pool = VerbosePooling(2, 2, stride=2)\n",
"x = np.random.rand(1, 2, 16, 32)\n",
"print(f\"x.shape: {x.shape}\")\n",
"\n",
"print()\n",
"out = pool.forward(x)\n",
"print(f\"out.shape: {out.shape}\")\n",
"\n",
"print()\n",
"dout = np.ones_like(out)  # dummy upstream gradient with the output's shape\n",
"dx = pool.backward(dout)\n",
"print(f\"dx.shape: {dx.shape}\")\n",
"\n",
"# Windows do not overlap (stride == pool size) and pad == 0, so every\n",
"# upstream gradient lands on exactly one input element\n",
"assert dx.shape == x.shape\n",
"assert np.isclose(dx.sum(), dout.sum())" ] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }