8.4. Convolution / Pooling レイヤの実装¶
8.4.1. 4次元配列¶
[1]:
import sys, os
sys.path.append(os.path.abspath(os.path.join('..', 'sample')))
import numpy as np
[2]:
x = np.random.rand(10, 1, 28, 28) # generate random data: 10 samples, 1 channel, 28x28 spatial values each
x.shape
[2]:
(10, 1, 28, 28)
[3]:
x[0].shape # access the first sample
[3]:
(1, 28, 28)
[4]:
x[1].shape # access the second sample
[4]:
(1, 28, 28)
[5]:
x[0, 0] # access the spatial data of the first channel of the first sample (a 28x28 2-D array)
[5]:
array([[7.46455120e-01, 2.27963876e-01, 8.27426402e-01, 8.26803413e-01,
8.60531394e-01, 5.11418231e-01, 5.27047620e-01, 7.03399015e-01,
7.32704811e-02, 1.09308993e-01, 8.15014474e-01, 3.14114566e-01,
9.24025136e-01, 2.80541838e-01, 6.79584986e-03, 1.10792427e-01,
4.48246393e-01, 8.82760812e-02, 9.05718234e-01, 7.86133933e-01,
2.55209058e-01, 5.22796916e-01, 7.02683868e-01, 1.94566704e-01,
6.36620563e-01, 5.16364544e-01, 3.03474860e-01, 2.11320952e-01],
[4.03951927e-01, 1.87770854e-01, 5.80080664e-01, 8.34256457e-01,
9.39314094e-01, 2.45657250e-01, 8.65530830e-01, 9.20854345e-01,
1.56331998e-01, 7.62897476e-01, 8.67818851e-01, 3.70141368e-01,
2.84505630e-01, 3.78607312e-01, 5.58315978e-01, 9.20036356e-01,
2.58880224e-01, 9.56287150e-01, 2.20749950e-01, 6.86061026e-01,
1.04820287e-01, 4.47132139e-01, 7.62985191e-01, 8.12218572e-01,
3.59674603e-01, 2.05085311e-01, 8.31086229e-01, 9.64650123e-01],
[5.28956298e-01, 5.52748122e-01, 4.79920791e-01, 2.81148690e-01,
4.35900769e-01, 6.95021156e-01, 4.86216713e-01, 1.22245813e-01,
8.62867648e-01, 4.40538693e-01, 9.72116112e-01, 3.63611287e-02,
9.14045729e-01, 5.19092159e-01, 9.72350122e-01, 7.63945661e-01,
3.55482099e-01, 6.53232699e-01, 8.65791529e-01, 3.22821152e-01,
2.62456741e-01, 6.63374157e-01, 4.10376090e-01, 4.24009667e-01,
9.87274223e-01, 4.62407377e-02, 6.56456496e-01, 3.13730495e-01],
[7.18606295e-01, 9.10762806e-01, 5.00487368e-02, 8.09228955e-01,
5.67282967e-01, 9.33370771e-02, 1.98947330e-01, 1.62631070e-01,
6.00090613e-01, 2.49121534e-01, 1.14705218e-01, 5.26247066e-01,
4.25716715e-01, 3.90532401e-01, 7.46411969e-01, 5.01446742e-01,
5.86019162e-01, 3.41213168e-01, 9.69051708e-01, 3.29134806e-02,
7.34018633e-01, 1.93056492e-01, 1.93553843e-01, 2.23852431e-01,
6.96258639e-01, 4.11262792e-01, 2.28029398e-01, 5.22302539e-01],
[4.23475670e-01, 1.65699557e-01, 2.96014552e-01, 1.94950327e-01,
2.35840055e-01, 8.02136570e-01, 6.36881297e-01, 5.18361323e-01,
7.30289280e-01, 9.34376772e-01, 6.46827533e-01, 4.38049502e-01,
6.29735765e-01, 2.87765802e-01, 9.15467672e-01, 6.16203016e-01,
6.13656954e-01, 9.95022410e-01, 3.79982610e-01, 5.71789728e-01,
5.86652106e-01, 1.46933247e-01, 2.31418917e-01, 4.79836605e-01,
7.57234727e-01, 9.93779261e-02, 2.31958748e-01, 7.50149019e-01],
[6.82747995e-01, 2.30982281e-01, 4.16005629e-02, 5.88359095e-01,
9.33558150e-01, 7.17475803e-01, 1.24411547e-01, 6.99189586e-01,
5.88954457e-01, 8.13507864e-01, 6.48615165e-01, 2.12180529e-01,
6.31167147e-01, 1.55236433e-01, 9.29129782e-01, 4.68010480e-01,
6.13947997e-01, 7.93487018e-01, 7.86107201e-01, 9.50296398e-01,
4.68254118e-01, 5.67975773e-01, 2.75214473e-01, 2.27449385e-01,
5.01343857e-01, 2.16418773e-01, 7.74640074e-01, 2.07921267e-01],
[3.58466076e-01, 9.13461301e-01, 7.59217891e-02, 3.64101937e-01,
2.70674137e-01, 7.04886786e-01, 4.85662499e-01, 5.37670004e-01,
1.53445784e-01, 6.09000782e-01, 5.69970931e-01, 9.59354251e-01,
9.22195823e-01, 1.14199247e-01, 8.26350814e-01, 6.61703847e-01,
1.60396802e-01, 4.06424995e-01, 9.54833517e-01, 6.03657240e-01,
2.33391729e-01, 4.01521639e-01, 3.22540182e-01, 2.30126023e-01,
2.61640163e-01, 6.32554921e-01, 4.41241830e-01, 2.59691392e-01],
[8.33443576e-01, 1.37501556e-01, 6.97651779e-01, 3.54500895e-01,
5.84440652e-01, 9.35578521e-01, 7.18540080e-01, 9.15061370e-01,
9.56877762e-01, 2.95650963e-01, 2.99195464e-01, 1.26332294e-01,
3.52259685e-01, 4.19420867e-01, 4.42227497e-01, 2.92890827e-01,
6.73741634e-01, 4.15804942e-01, 7.92465504e-01, 9.71236774e-01,
3.57554181e-01, 6.94313630e-01, 3.26568391e-01, 6.46365892e-01,
7.07729163e-01, 7.09206960e-01, 7.75490307e-01, 4.21483438e-01],
[8.14915163e-03, 5.17055862e-01, 4.39072422e-01, 8.49363548e-01,
3.14305426e-02, 2.32023646e-01, 9.61729849e-01, 1.65631730e-01,
7.63607241e-01, 7.40085716e-02, 3.43271602e-02, 4.43440555e-01,
9.41311324e-01, 9.98836578e-01, 3.44661035e-01, 7.14372472e-01,
8.98675041e-02, 5.76569855e-01, 8.54737696e-01, 7.49008869e-01,
1.56247113e-01, 6.55426268e-01, 6.24778850e-01, 6.10067145e-01,
4.27663835e-01, 1.37521906e-01, 4.88969758e-01, 8.02401202e-01],
[8.86015600e-01, 5.01397989e-01, 5.90601229e-02, 8.65306596e-01,
4.27583666e-01, 1.51195036e-01, 4.71110749e-02, 9.85532894e-01,
6.14647733e-01, 6.32465831e-01, 3.28685744e-01, 2.93326353e-01,
6.51783237e-01, 4.87892684e-01, 2.14292299e-01, 6.07874852e-01,
9.31987762e-01, 7.43693241e-01, 7.28212503e-01, 6.50320583e-01,
6.47876303e-01, 4.00688742e-02, 3.14540384e-01, 9.96390996e-01,
4.60466245e-01, 1.78761983e-01, 4.22862838e-02, 6.19093258e-01],
[9.51484168e-01, 6.87546923e-01, 3.61576240e-01, 9.04298962e-01,
4.69256161e-01, 9.04771944e-01, 1.94895147e-01, 4.06794905e-01,
1.99110372e-01, 4.31780007e-01, 4.16901194e-01, 7.98212619e-02,
7.19721372e-01, 7.49251099e-01, 8.45644922e-01, 1.41772809e-01,
1.05602287e-01, 6.69576440e-02, 6.28086657e-01, 6.16504993e-01,
8.61918093e-01, 2.35597159e-01, 6.92240310e-01, 2.42599478e-01,
5.77458851e-01, 6.39502674e-01, 1.00905682e-01, 9.13104018e-01],
[4.84151537e-01, 1.47566455e-01, 7.23894307e-01, 8.32892863e-01,
9.45434932e-01, 8.32320799e-01, 6.76105737e-01, 5.79858742e-01,
1.19561335e-01, 2.70875589e-02, 1.81584326e-01, 3.86005149e-01,
9.11250426e-01, 8.14528602e-01, 6.52829314e-01, 2.95412580e-01,
6.83016162e-01, 6.16147171e-01, 4.14563219e-04, 6.15367620e-01,
8.23406599e-02, 1.92954798e-01, 9.68460770e-01, 4.75632089e-01,
2.72164760e-01, 9.37508463e-01, 8.33939601e-01, 5.71895402e-01],
[8.24197041e-01, 5.06285575e-01, 8.66060624e-01, 4.45276569e-03,
7.85483618e-02, 9.64055843e-01, 8.23314108e-01, 9.07433479e-01,
6.38287109e-01, 8.16646323e-01, 6.83621283e-01, 1.35911722e-02,
1.14781164e-01, 1.04495873e-01, 3.64866295e-02, 4.14716225e-01,
5.60912032e-02, 6.47171268e-01, 8.40610775e-01, 9.62700512e-01,
7.06501458e-02, 8.25360600e-01, 7.62349703e-01, 8.26256597e-01,
4.43647025e-01, 3.25188836e-01, 7.29252270e-01, 8.88779579e-01],
[5.72501121e-01, 6.87171452e-01, 6.90831498e-01, 3.55823391e-01,
8.28055797e-01, 3.69172604e-01, 2.74057542e-01, 9.18860834e-02,
2.16264888e-01, 9.52469910e-02, 6.66412171e-01, 4.93963517e-01,
7.19060568e-01, 9.70385890e-01, 8.55261763e-02, 7.17406599e-02,
2.68435031e-01, 7.61333140e-01, 4.25854845e-01, 3.97499868e-01,
7.52547205e-01, 2.41306161e-01, 2.53409801e-01, 4.03980564e-01,
3.84571924e-01, 5.91608123e-01, 4.62163291e-01, 7.49126877e-01],
[9.25316754e-01, 9.23039112e-01, 8.71130809e-02, 9.36325716e-01,
9.95759487e-01, 1.19317067e-01, 3.16613508e-01, 7.78033316e-01,
4.00319805e-01, 4.45151894e-01, 9.31948346e-01, 6.02004659e-01,
4.89465719e-01, 8.55296974e-01, 6.82682328e-02, 6.69848190e-01,
6.97620674e-02, 2.26382602e-01, 6.48364301e-01, 3.12417911e-02,
2.07617937e-01, 6.58064313e-01, 9.91172575e-01, 5.77215712e-01,
1.44586353e-01, 4.06569044e-01, 5.82562135e-01, 5.18761125e-01],
[1.39701323e-01, 5.33849445e-01, 1.78333341e-02, 7.36128529e-01,
6.21878345e-01, 3.52825186e-01, 5.36497500e-03, 1.26344083e-01,
4.45264773e-01, 4.71438644e-01, 3.32636427e-01, 3.54839822e-01,
9.37085984e-01, 1.71536560e-01, 2.13848610e-02, 5.55817934e-01,
7.90145246e-01, 1.31151994e-01, 2.98448043e-01, 4.94604124e-01,
9.43620343e-01, 5.35602472e-02, 3.16143949e-01, 1.74964831e-01,
7.19499720e-01, 1.31382913e-01, 7.60031899e-02, 5.66504602e-01],
[1.56467006e-01, 4.71481487e-01, 4.13741235e-01, 2.41054872e-01,
3.87437537e-01, 6.51896773e-02, 6.07919911e-01, 4.66169349e-01,
6.50662995e-01, 9.81575207e-03, 3.19298200e-01, 9.67503877e-01,
4.52597496e-03, 1.65415847e-01, 7.67148083e-01, 1.95315427e-01,
4.26667194e-01, 6.61384978e-01, 2.77242372e-01, 5.95361187e-01,
8.10458079e-01, 1.67760315e-01, 3.78285960e-01, 8.35850786e-02,
7.40909903e-01, 9.07375069e-01, 1.44170694e-01, 5.15366144e-01],
[7.49555829e-03, 5.11959771e-01, 2.85293564e-01, 6.64528004e-01,
8.62378006e-01, 4.71103873e-02, 4.54527113e-01, 7.96920477e-01,
5.10112624e-01, 9.25747312e-01, 7.38245639e-01, 1.29147742e-01,
5.66681941e-01, 2.59535327e-01, 4.00606474e-01, 4.55794248e-01,
6.12936597e-01, 7.80401311e-01, 9.34560912e-01, 5.03399736e-01,
5.96915325e-01, 1.41315992e-02, 2.79090593e-01, 3.37042978e-01,
9.20900064e-01, 8.74347084e-03, 6.94411881e-01, 1.59258232e-02],
[6.41399526e-01, 3.62281841e-01, 4.60624445e-01, 3.46285018e-01,
9.99380558e-01, 8.61930205e-01, 8.53832245e-01, 5.57079181e-01,
3.47734927e-01, 1.13003939e-01, 6.80945107e-01, 4.60689749e-01,
1.62382832e-01, 9.31673595e-01, 2.48753653e-01, 6.52301224e-01,
7.35557585e-01, 7.13339367e-01, 8.41276782e-01, 6.06089795e-01,
5.38195803e-01, 8.01575372e-02, 1.77487439e-01, 8.23636477e-01,
5.57477258e-01, 3.88232770e-02, 9.62549655e-01, 1.92759082e-01],
[8.95211127e-01, 5.85759221e-01, 7.41528444e-02, 7.95912051e-01,
5.28049254e-01, 1.12707108e-01, 1.29677564e-01, 1.53757394e-01,
6.70237211e-01, 6.37532201e-01, 1.17483989e-01, 8.12396578e-01,
6.66003608e-01, 9.98811030e-02, 9.84756276e-01, 1.60092791e-01,
6.32020017e-01, 6.02422379e-01, 5.27858455e-02, 3.73056931e-01,
6.76505665e-01, 5.70869394e-01, 5.39477126e-01, 4.65371437e-01,
3.10655539e-01, 7.42982786e-03, 9.40493771e-01, 8.41843487e-01],
[5.51324343e-01, 9.71110641e-01, 5.90230831e-01, 8.13148645e-01,
4.27976073e-01, 5.56336557e-01, 1.14200120e-01, 9.06750620e-01,
6.06534111e-01, 6.29953377e-01, 7.72502297e-01, 7.79928307e-01,
3.31396467e-01, 6.16598091e-01, 5.74269671e-01, 9.21892328e-01,
4.81680161e-01, 9.63003253e-01, 8.61772631e-02, 4.24563847e-01,
2.65460961e-01, 8.81578303e-01, 3.69491498e-01, 7.90004429e-01,
7.47649224e-01, 6.27816654e-01, 8.56111462e-01, 7.33647893e-01],
[7.93071000e-01, 9.39395190e-01, 8.06050527e-01, 2.46194564e-01,
1.61274892e-01, 9.52770249e-01, 6.79299926e-01, 8.16650765e-01,
1.69010824e-01, 8.81568644e-01, 4.74255451e-01, 9.14069211e-01,
2.14361846e-01, 3.71657844e-01, 3.88492542e-01, 3.27328445e-01,
1.66717286e-01, 9.34844246e-02, 6.18202260e-01, 5.30945134e-01,
3.91600230e-01, 5.43912342e-01, 4.61725734e-01, 1.66011244e-01,
4.77644354e-01, 9.99968315e-01, 2.92467773e-01, 7.06248658e-01],
[3.55233392e-01, 6.61529015e-01, 3.45724792e-01, 7.35879313e-01,
3.13486359e-01, 8.89689578e-01, 2.88355462e-01, 4.67812371e-01,
3.42406207e-01, 1.54923541e-01, 1.71601436e-02, 9.06808621e-01,
6.26267702e-01, 8.46050504e-01, 4.97191957e-01, 4.42864747e-01,
5.51167422e-01, 7.37057800e-01, 6.47698534e-01, 6.78910403e-02,
1.92058568e-01, 3.89085151e-01, 9.34290783e-01, 7.08902162e-01,
5.64528391e-01, 5.23889856e-01, 1.03287484e-01, 5.64184354e-01],
[2.05301385e-01, 1.33711294e-01, 4.50556811e-01, 3.80711564e-01,
4.40834009e-01, 8.15779833e-01, 3.91492425e-01, 2.27680537e-01,
5.19823039e-01, 5.19845330e-02, 8.16212104e-01, 7.64287820e-01,
7.54757569e-01, 8.65711656e-01, 3.69304766e-01, 6.49056286e-01,
9.57079183e-02, 4.45317716e-02, 7.92986942e-01, 7.65407464e-01,
5.53620957e-01, 5.97657478e-01, 7.53313072e-01, 1.31745145e-01,
6.11002149e-01, 8.21139387e-01, 5.20170600e-01, 1.92531292e-01],
[2.48764504e-01, 8.57158462e-01, 1.60636772e-01, 9.36213070e-01,
4.85975137e-01, 6.43898513e-01, 5.08286137e-02, 3.04333012e-02,
2.03940063e-01, 8.70897901e-01, 4.81609409e-01, 6.15979279e-01,
8.61882806e-01, 3.43366836e-01, 4.39118719e-01, 7.19994211e-01,
5.23686672e-01, 5.32908196e-01, 2.29339751e-01, 5.80869441e-01,
4.46953617e-02, 4.21767573e-01, 8.26056443e-01, 1.12113206e-01,
9.70801973e-01, 9.28865112e-01, 4.90708466e-01, 2.16601219e-01],
[9.16337738e-01, 8.96942002e-01, 2.90703863e-01, 1.85660447e-01,
2.26900348e-01, 4.13983007e-01, 2.69570598e-01, 2.88570991e-01,
2.78554594e-01, 4.19120959e-01, 2.54393111e-01, 7.02132445e-01,
7.35633139e-01, 7.43792127e-01, 2.97035809e-01, 3.37200462e-01,
4.37576240e-01, 3.68025116e-01, 1.99905348e-01, 1.73117211e-01,
2.40863480e-01, 2.62601807e-02, 8.53133932e-01, 5.22292208e-01,
6.42894472e-01, 4.29176565e-01, 4.14385911e-01, 3.05540051e-01],
[9.73592820e-01, 9.15234502e-01, 1.64450167e-01, 5.58369077e-03,
7.16045055e-02, 7.42231006e-01, 3.66223842e-01, 7.31008433e-01,
4.85227212e-01, 5.07172860e-02, 5.26328427e-01, 2.14019661e-01,
3.56974984e-01, 8.06161037e-01, 7.08070836e-01, 7.50775252e-01,
2.64836881e-01, 9.84340330e-01, 5.42772514e-01, 4.93113767e-01,
5.20820336e-01, 5.30817789e-01, 5.62515927e-02, 4.76085351e-01,
8.25129849e-01, 6.98940576e-01, 5.21246816e-01, 3.22277406e-01],
[5.44108028e-01, 1.01284146e-01, 9.22323594e-01, 2.24490182e-01,
2.51889361e-01, 8.27815687e-01, 3.32006987e-01, 6.28728377e-02,
5.09628768e-02, 2.79569256e-01, 7.56242927e-01, 4.44651964e-01,
6.85801870e-02, 8.05050096e-01, 2.43117627e-02, 4.94562512e-01,
9.57729021e-01, 6.47568055e-02, 3.96393700e-01, 2.90967161e-02,
3.48168205e-01, 6.79447073e-01, 4.05632771e-03, 4.20028183e-01,
3.02600544e-01, 3.55424447e-01, 5.13069265e-01, 9.94934160e-01]])
8.4.2. im2colによる展開¶
NumPyでは、for 文を使うと処理が遅くなってしまう。
そのため for 文を使わず、im2col を使う。
im2col は image to column の略で、「画像から行列へ」という意味。
8.4.3. Convolution レイヤの実装¶
[6]:
import sys, os
sys.path.append(os.path.abspath(os.path.join('..', 'sample')))
import numpy as np
from common.util import im2col
# Batch size 1, 3 channels, 7x7 data.
# A 5x5 window with stride 1 and no padding fits at 3x3 = 9 positions, and
# each position covers 3 * 5 * 5 = 75 values, hence the printed (9, 75).
x1 = np.random.rand(1, 3, 7, 7)
col1 = im2col(x1, 5, 5, stride=1, pad=0)
print(f'col1.shape: {col1.shape}')
# Batch size 10, 3 channels, 7x7 data.
# Ten samples give ten times as many rows: the printed shape is (90, 75).
x2 = np.random.rand(10, 3, 7, 7)
col2 = im2col(x2, 5, 5, stride=1, pad=0)
print(f'col2.shape: {col2.shape}')
col1.shape: (9, 75)
col2.shape: (90, 75)
[7]:
class Comvolution:
    """Convolution layer evaluated as a single matrix product.

    The input is unfolded with ``im2col`` so that every filter position
    becomes one row, the filters are flattened into the columns of a 2-D
    matrix, and the convolution is computed as ``col . col_W + b``.

    The class keeps the notebook's original (misspelled) name so existing
    references keep working; ``Convolution`` is bound to the same class
    right after this definition.

    Args:
        W: Filter weights with four axes, unpacked as ``(FN, C, FH, FW)``.
        b: Bias added to the ``(rows, FN)`` product, so it must broadcast
            against that shape.
        stride: Step of the filter window, forwarded to ``im2col`` and
            ``col2im``.
        pad: Padding amount, forwarded to ``im2col`` and ``col2im``.
    """

    def __init__(self, W, b, stride=1, pad=0):
        self.W = W
        self.b = b
        self.stride = stride
        self.pad = pad

        # Intermediate values cached by forward() for use in backward().
        self.x = None
        self.col = None
        self.col_W = None

        # Gradients of the weights and the bias, filled in by backward().
        self.dW = None
        self.db = None

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Input with four axes, unpacked as ``(N, C, H, W)``.

        Returns:
            Array of shape ``(N, FN, out_h, out_w)``.
        """
        FN, _, FH, FW = self.W.shape
        N, _, H, W = x.shape

        # Output height/width. Floor division gives the same value as the
        # former int(1 + ... / stride) for every window that fits.
        out_h = (H + 2 * self.pad - FH) // self.stride + 1
        out_w = (W + 2 * self.pad - FW) // self.stride + 1

        # Unfold the input; im2col is expected to return one row per filter
        # position (N * out_h * out_w rows in total).
        col = im2col(x, FH, FW, self.stride, self.pad)
        # Flatten each filter into one column: (C * FH * FW, FN).
        col_W = self.W.reshape(FN, -1).T

        # Matrix product of the unfolded input and the unfolded filters.
        out = np.dot(col, col_W) + self.b
        # Rows are ordered (N, out_h, out_w); move the filter axis to
        # position 1 to obtain (N, FN, out_h, out_w).
        out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)

        self.x = x
        self.col = col
        self.col_W = col_W
        return out

    def backward(self, dout):
        """Run the backward pass; forward() must have been called first.

        Stores the parameter gradients in ``self.dW`` (same shape as ``W``)
        and ``self.db``.

        Args:
            dout: Upstream gradient whose axis 1 has length ``FN``, i.e. the
                layout returned by forward().

        Returns:
            Whatever ``col2im`` produces for ``self.x.shape`` -- presumably
            the gradient with respect to the input, shaped like ``x``.
        """
        # The cell defining this class imports only im2col; importing col2im
        # here keeps backward() from failing with NameError.
        from common.util import col2im

        FN, C, FH, FW = self.W.shape

        # Put the filter axis last and flatten to (N * out_h * out_w, FN) so
        # the rows line up with the rows of self.col.
        dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)

        self.db = np.sum(dout, axis=0)
        self.dW = np.dot(self.col.T, dout)
        # (C * FH * FW, FN) -> (FN, C, FH, FW), the layout of self.W.
        self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)

        dcol = np.dot(dout, self.col_W.T)
        dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)
        return dx


# Correctly spelled alias; the original name stays available.
Convolution = Comvolution
8.4.4. Poolingレイヤの実装¶
[8]:
class Pooling:
    """Max pooling layer implemented with ``im2col`` / ``col2im``.

    Args:
        pool_h: Height of the pooling window.
        pool_w: Width of the pooling window.
        stride: Step of the pooling window, forwarded to ``im2col`` and
            ``col2im``.
        pad: Padding amount, forwarded to ``im2col`` and ``col2im``.
    """

    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad

        # Cached by forward() for use in backward().
        self.x = None
        self.arg_max = None

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Input with four axes, unpacked as ``(N, C, H, W)``.

        Returns:
            Array of shape ``(N, C, out_h, out_w)`` holding the maximum of
            each pooling window.
        """
        N, C, H, W = x.shape

        # Output height/width. The padding handed to im2col has to be part
        # of the size, otherwise the reshape below fails whenever pad > 0.
        # NOTE(review): assumes im2col pads both sides of each spatial axis,
        # as the convolution layer's size formula does -- confirm.
        out_h = (H + 2 * self.pad - self.pool_h) // self.stride + 1
        out_w = (W + 2 * self.pad - self.pool_w) // self.stride + 1

        # Unfold the input, then regroup it so each row holds the values of
        # one pooling window of one channel.
        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        col = col.reshape(-1, self.pool_h * self.pool_w)

        # Per-row maximum; the position is kept for backward().
        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)

        # Rows are ordered (N, out_h, out_w, C); move the channel axis to
        # position 1.
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        self.x = x
        self.arg_max = arg_max
        return out

    def backward(self, dout):
        """Run the backward pass; forward() must have been called first.

        Args:
            dout: Upstream gradient in the layout returned by forward().

        Returns:
            Whatever ``col2im`` produces for ``self.x.shape`` -- presumably
            the gradient with respect to the input, shaped like ``x``.
        """
        # The cells executed before this class import only im2col; importing
        # col2im here keeps backward() from failing with NameError.
        from common.util import col2im

        # Channel axis last, matching the row order used in forward().
        dout = dout.transpose(0, 2, 3, 1)

        pool_size = self.pool_h * self.pool_w
        dmax = np.zeros((dout.size, pool_size))
        # Each upstream gradient goes to the slot that held the maximum in
        # forward(); every other slot of the window stays zero.
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        dmax = dmax.reshape(dout.shape + (pool_size,))

        # Collapse the first three axes into rows, leaving C * pool_size
        # columns for col2im.
        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)
        return dx
[24]:
import sys, os
sys.path.append(os.path.abspath(os.path.join('..', 'sample')))
import numpy as np
from common.util import im2col
from common.util import col2im
class Pooling:
    """Max pooling layer, instrumented with prints that trace backward().

    Same layer as the plain pooling class, except that backward() prints
    the array shapes before and after each of its two reshapes.

    Args:
        pool_h: Height of the pooling window.
        pool_w: Width of the pooling window.
        stride: Step of the pooling window, forwarded to ``im2col`` and
            ``col2im``.
        pad: Padding amount, forwarded to ``im2col`` and ``col2im``.
    """

    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad

        # Cached by forward() for use in backward().
        self.x = None
        self.arg_max = None

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Input with four axes, unpacked as ``(N, C, H, W)``.

        Returns:
            Array of shape ``(N, C, out_h, out_w)`` holding the maximum of
            each pooling window.
        """
        N, C, H, W = x.shape

        # Output height/width. The padding handed to im2col has to be part
        # of the size, otherwise the reshape below fails whenever pad > 0.
        # NOTE(review): assumes im2col pads both sides of each spatial axis,
        # as the convolution layer's size formula does -- confirm.
        out_h = (H + 2 * self.pad - self.pool_h) // self.stride + 1
        out_w = (W + 2 * self.pad - self.pool_w) // self.stride + 1

        # Unfold the input, then regroup it so each row holds the values of
        # one pooling window of one channel.
        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        col = col.reshape(-1, self.pool_h * self.pool_w)

        # Per-row maximum; the position is kept for backward().
        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)

        # Rows are ordered (N, out_h, out_w, C); move the channel axis to
        # position 1.
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        self.x = x
        self.arg_max = arg_max
        return out

    def backward(self, dout):
        """Run the backward pass and print the shapes around each reshape.

        forward() must have been called first.

        Args:
            dout: Upstream gradient in the layout returned by forward().

        Returns:
            Whatever ``col2im`` produces for ``self.x.shape`` -- presumably
            the gradient with respect to the input, shaped like ``x``.
        """
        # Channel axis last, matching the row order used in forward().
        dout = dout.transpose(0, 2, 3, 1)

        pool_size = self.pool_h * self.pool_w
        dmax = np.zeros((dout.size, pool_size))
        # Each upstream gradient goes to the slot that held the maximum in
        # forward(); every other slot of the window stays zero.
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()

        # First reshape: (dout.size, pool_size) -> dout.shape + (pool_size,).
        print("----reshape.1----")
        print(f"dmax.shape: {dmax.shape}")
        print(f"dout.shape: {dout.shape}")
        print(f"(pool_size,): {(pool_size,)}")
        dmax = dmax.reshape(dout.shape + (pool_size,))
        print(f"dmax.shape: {dmax.shape}")

        # Second reshape: collapse the first three axes into rows, leaving
        # C * pool_size columns for col2im.
        print("----reshape.2----")
        print(f"dmax.shape[0]: {dmax.shape[0]}")
        print(f"dmax.shape[1]: {dmax.shape[1]}")
        print(f"dmax.shape[2]: {dmax.shape[2]}")
        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        print(f"dcol.shape: {dcol.shape}")

        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)
        return dx
# 2x2 pooling window moved with stride 2; with this input it halves each
# spatial dimension (16x32 -> 8x16 in the printed shapes).
pool = Pooling(2, 2, stride=2)
x = np.random.rand(1, 2, 16, 32)
print(f"x.shape: {x.shape}")
print()
# The forward output is reused as the upstream gradient; only the shapes
# matter for this check.
dout = pool.forward(x)
print(f"dout.shape: {dout.shape}")
print()
# backward() prints its intermediate shapes; the result printed below has
# the same shape as x.
inX = pool.backward(dout)
print(f"inX.shape: {inX.shape}")
x.shape: (1, 2, 16, 32)
dout.shape: (1, 2, 8, 16)
----reshape.1----
dmax.shape: (256, 4)
dout.shape: (1, 8, 16, 2)
(pool_size,): (4,)
dmax.shape: (1, 8, 16, 2, 4)
----reshape.2----
dmax.shape[0]: 1
dmax.shape[1]: 8
dmax.shape[2]: 16
dcol.shape: (128, 8)
inX.shape: (1, 2, 16, 32)