# Implementing simple layers
# -- Multiplication layer --


class MulLayer:
    """Multiplication node of a computational graph.

    Attributes
    ----------
    x : number or None
        First operand of the most recent ``forward`` call (``None`` until then).
    y : number or None
        Second operand of the most recent ``forward`` call (``None`` until then).
    """

    def __init__(self):
        """Create the layer with no cached operands."""
        self.x = self.y = None

    def forward(self, x, y):
        """Multiply the two inputs and cache them for the backward pass.

        Parameters
        ----------
        x : number
            First operand (the demo below passes ints and a float).
        y : number
            Second operand.

        Returns
        -------
        out : number
            The product ``x * y``.
        """
        # Both operands are needed later: each input's gradient is the upstream
        # gradient scaled by the *other* operand.
        self.x, self.y = x, y
        return x * y

    def backward(self, dout):
        """Propagate the upstream gradient through the multiplication.

        ``forward`` must have been called first so that ``self.x`` and
        ``self.y`` hold the operands.

        Parameters
        ----------
        dout : number
            Gradient arriving from the downstream (output) side.

        Returns
        -------
        dx : number
            Gradient with respect to ``x``, i.e. ``dout * y``.
        dy : number
            Gradient with respect to ``y``, i.e. ``dout * x``.
        """
        return dout * self.y, dout * self.x


# Example: two apples at 100 each, with a tax multiplier of 1.1.
apple, apple_num, tax = 100, 2, 1.1

mul_apple_layer, mul_tax_layer = MulLayer(), MulLayer()

# Forward pass: (unit price * count) * tax.
apple_price = mul_apple_layer.forward(apple, apple_num)
price = mul_tax_layer.forward(apple_price, tax)

# Backward pass: seed with d(price)/d(price) = 1 and visit the layers in
# the reverse of the forward order.
dprice = 1
dapple_price, dtax = mul_tax_layer.backward(dprice)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)

print("price:", int(price))
print("dApple:", dapple)
print("dApple_num:", int(dapple_num))
print("dTax:", dtax)
# -- Addition layer --


class AddLayer:
    """Addition node of a computational graph (keeps no state)."""

    def __init__(self):
        """Nothing to initialise: addition needs no cached inputs."""

    def forward(self, x, y):
        """Add the two inputs.

        Parameters
        ----------
        x : number
            First operand.
        y : number
            Second operand.

        Returns
        -------
        out : number
            The sum ``x + y``.
        """
        return x + y

    def backward(self, dout):
        """Propagate the upstream gradient through the addition.

        Parameters
        ----------
        dout : number
            Gradient arriving from the downstream (output) side.

        Returns
        -------
        dx : number
            Gradient with respect to ``x`` (``dout * 1``).
        dy : number
            Gradient with respect to ``y`` (``dout * 1``).
        """
        # The local derivative of a sum is 1 for each input, so the upstream
        # gradient is handed to both branches unchanged.
        return dout * 1, dout * 1


# Example: 2 apples at 100 and 3 oranges at 150, with a tax multiplier of 1.1.
# NOTE: relies on the MulLayer class defined in an earlier cell.
apple, apple_num = 100, 2
orange, orange_num = 150, 3
tax = 1.1

# Layers
mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()

# Forward pass; the numbered steps are undone in reverse order below.
apple_price = mul_apple_layer.forward(apple, apple_num)  # (1)
orange_price = mul_orange_layer.forward(orange, orange_num)  # (2)
all_price = add_apple_orange_layer.forward(apple_price, orange_price)  # (3)
price = mul_tax_layer.forward(all_price, tax)  # (4)

# Backward pass, seeded with d(price)/d(price) = 1.
dprice = 1
dall_price, dtax = mul_tax_layer.backward(dprice)  # (4)
dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price)  # (3)
dorange, dorange_num = mul_orange_layer.backward(dorange_price)  # (2)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)  # (1)

print("price:", int(price))
print("dApple:", dapple)
print("dApple_num:", int(dapple_num))
print("dOrange:", dorange)
print("dOrange_num:", int(dorange_num))
print("dTax:", dtax)