{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 損失関数\n", "\n", "* ニューラルネットワークでは、一つの指標を手掛かりに最適なパラメータを探索する\n", "* ニューラルネットワークの指標は **損失関数** (loss function)\n", "* 損失関数は、ニューラルネットワークの性能の悪さを示す" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2乗和誤差 (mean squared error)\n", "\n", "$$\n", " E = \\frac{1}{2} \\sum_{k}^{} (y_k-t_k)^2\n", "$$" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def mean_squared_error(y, t):\n", " return 0.5 * np.sum((y-t)**2)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.09750000000000003" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 「2」を正解とする\n", "t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]\n", "\n", "# 「2」の確立が最も高い場合\n", "y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]\n", "mean_squared_error(np.array(y), np.array(t))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5975" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 「7」の確率が最も高い場合\n", "y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]\n", "mean_squared_error(np.array(y), np.array(t))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* この値から、一つ目の損失関数が小さいことがわかる" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 交差エントロピー誤差 (cross entropy error)\n", "\n", "$$\n", " E = - \\sum_{k}^{} t_k\\log{y_k}\n", "$$" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def cross_entropy_error(y, t):\n", " delta = 1e-7\n", " return -np.sum(t * np.log(y + delta))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* `delta` は、`np.log(0)` の計算時に`-inf`にならないための防止策" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.510825457099338" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 「2」を正解とする\n", "t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]\n", "\n", "# 「2」の確立が最も高い場合\n", "y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]\n", "cross_entropy_error(np.array(y), np.array(t))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2.302584092994546" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 「7」の確率が最も高い場合\n", "y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]\n", "cross_entropy_error(np.array(y), np.array(t))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ミニバッチ学習\n", "\n", "訓練データすべての損失関数の和は、交差エントロピー誤差の場合\n", "\n", "$$\n", " E = -\\frac{1}{N} \\sum_{n}\\sum_{k} t_{nk}\\log{y_{nk}}\n", "$$\n", "\n", "* $t_{nk}$は$n$個目のデータの、$k$番目の値のこと\n", "* $y_{nk}$はニューラルネットワークの出力\n", "* $t_{nk}$は訓練データ\n", "\n", "訓練データすべてを計算するには時間がかかってしまうので、データの一部を選んで「近似」として学習させることを、ミニバッチ学習という。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(60000, 784)\n", "(60000, 10)\n" ] } ], "source": [ "import sys, os\n", "sys.path.append(os.path.abspath(os.path.join('..', 'sample')))\n", "\n", "import numpy as np\n", "from dataset.mnist import load_mnist # サンプルにあるPythonモジュール\n", "\n", "(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)\n", "\n", "print(x_train.shape)\n", "print(t_train.shape)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def cross_entropy_error(y, t):\n", " \"\"\"\n", " 交差エントロピー誤差を出力します\n", " \n", " Parameters\n", " ----------\n", " y : np.array\n", " ニューラルネットワークの出力\n", " t : np.array\n", " 教師データ\n", " \"\"\"\n", " if y.ndim == 1:\n", " t = t.reshape(1, t.size)\n", " y = y.reshape(1, y.size)\n", " \n", " batch_size = y.shape[0]\n", " \n", " # 教師データがone-hot表現のとき\n", " # return -np.sum(t * np.log(y) / batch_size)\n", "\n", " # 教師データがラベルの場合\n", " return -np.sum(np.log(y[np.arange(batch_size), t])) / batch_size" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## なぜ損失関数を設定するのか\n", "\n", "* パラメータの微分(正確には勾配)を計算し、その値と損失関数の値を比較するため\n", "* 微分の値がマイナスであれば、重みパラメータをプラスに補正する\n", "* 微分の値がプラスであれば、重みパラメータをマイナスに補正する\n", "* 微分の値が0であれば、損失関数の値も変化しないので更新をストップする\n", "\n", "### なぜ認識制度ではないのか\n", "\n", "例えば、100枚中32枚を正しく認識できたら、認識制度が32%になる。しかし、重みパラメータの値を少し変えただけでは認識制度が32%のままになる (32/100なので、細かい変化がつかめない)。仮に認識制度が向上しても、100枚中33枚、つまり33%と、値が整数でしか測れない。\n", "\n", "一方、損失関数は100%中何%が外れだったのか(ざ、ざっくり過ぎる説明…)がわかる" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }