{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 手書き数字認識" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "機械学習で問題を解くときは、「学習」と「推論」の2つのフェーズで行います。 \n", "ニューラルネットワークでは、「学習」は「訓練データ(学習データ)」を使用して重みパラメーターの学習を行い、「推論」では学習した重みパラメーターを使って、入力データの分類を行います。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MNISTデータセット\n", "\n", "今回は、MNISTデータセット (えむにすと)と呼ばれる、機械学習の分野で最も有名なデータを使います。\n", "\n", "MNISTデータセットは、0から9までの数字画像から構成されています。\n", "\n", "* 28 x 28のグレー画像 (1チャンネル)\n", "* 1ピクセルに 0 ~ 255 までの値" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## データを表示してみる" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. include:: ./6.MNIST/mnist_show.py" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "label: 5\n", "img.shape: (784,)\n", "img.shape: (28, 28)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAN80lEQVR4nO3df6hcdXrH8c+ncf3DrBpTMYasNhuRWBWbLRqLSl2RrD9QNOqWDVgsBrN/GHChhEr6xyolEuqP0qAsuYu6sWyzLqgYZVkVo6ZFCF5j1JjU1YrdjV6SSozG+KtJnv5xT+Su3vnOzcyZOZP7vF9wmZnzzJnzcLife87Md879OiIEYPL7k6YbANAfhB1IgrADSRB2IAnCDiRxRD83ZpuP/oEeiwiPt7yrI7vtS22/aftt27d281oAesudjrPbniLpd5IWSNou6SVJiyJia2EdjuxAj/XiyD5f0tsR8U5EfCnpV5Ku6uL1APRQN2GfJekPYx5vr5b9EdtLbA/bHu5iWwC61M0HdOOdKnzjND0ihiQNSZzGA03q5si+XdJJYx5/R9L73bUDoFe6CftLkk61/V3bR0r6kaR19bQFoG4dn8ZHxD7bSyU9JWmKpAci4o3aOgNQq46H3jraGO/ZgZ7ryZdqABw+CDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUii4ymbcXiYMmVKsX7sscf2dPtLly5tWTvqqKOK686dO7dYv/nmm4v1u+66q2Vt0aJFxXU///zzYn3lypXF+u23316sN6GrsNt+V9IeSfsl7YuIs+toCkD96jiyXxQRH9TwOgB6iPfsQBLdhj0kPW37ZdtLxnuC7SW2h20Pd7ktAF3o9jT+/Ih43/YJkp6x/V8RsWHsEyJiSNKQJNmOLrcHoENdHdkj4v3qdqekxyTNr6MpAPXrOOy2p9o++uB9ST+QtKWuxgDUq5vT+BmSHrN98HX+PSJ+W0tXk8zJJ59crB955JHF+nnnnVesX3DBBS1r06ZNK6577bXXFutN2r59e7G+atWqYn3hwoUta3v27Cmu++qrrxbrL7zwQrE+iDoOe0S8I+kvauwFQA8x9AYkQdiBJAg7kARhB5Ig7EASjujfl9om6zfo5s2bV6yvX7++WO/1ZaaD6sCBA8X6jTfeWKx/8sknHW97ZGSkWP/www+L9TfffLPjbfdaRHi85RzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlrMH369GJ948aNxfqcOXPqbKdW7XrfvXt3sX7RRRe1rH355ZfFdbN+/6BbjLMDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBJM2VyDXbt2FevLli0r1q+44opi/ZVXXinW2/1L5ZLNmzcX6wsWLCjW9+7dW6yfccYZLWu33HJLcV3UiyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTB9ewD4JhjjinW200vvHr16pa1xYsXF9e9/vrri/W1a9cW6xg8HV/PbvsB2zttbxmzbLrtZ2y/Vd0eV2ezAOo3kdP4X0i69GvLbpX0bEScKunZ6jGAAdY27BGxQdLXvw96laQ11f01kq6uuS8ANev0u/EzImJEkiJixPYJrZ5oe4mkJR1uB0BNen4hTEQMSRqS+IAOaFKnQ287bM+UpOp2Z30tAeiFTsO+TtIN1f0bJD1eTzsAeqXtabzttZK+L+l429sl/VTSSkm/tr1Y0u8l/bCXTU52H3/8cVfrf/TRRx2ve9NNNxXrDz/8cLHebo51DI62YY+IRS1KF9fcC4Ae4uuyQBKEHUiCsANJEHYgCcIOJMElrpPA1KlTW9aeeOKJ4roXXnhhsX7ZZZcV608//XSxjv5jymYgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJx9knulFNOKdY3bdpUrO/evbtYf+6554r14eHhlrX77ruvuG4/fzcnE8bZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtmTW7hwYbH+4IMPFutHH310x9tevnx5sf7QQw8V6yMjIx1vezJjnB1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcHUVnnnlmsX7PPfcU6xdf3Plkv6tXry7WV6xYUay/9957HW/7cNbxOLvtB2zvtL1lzLLbbL9ne3P1c3mdzQKo30RO438h6dJxlv9LRMyrfn5Tb1sA6tY27BGxQdKuPvQCoIe6+YBuqe3XqtP841o9yfYS28O2W/8zMgA912nYfybpFEnzJI1IurvVEyNiKCLOjoizO9wWgBp0FPaI2BER+yPigKSfS5pfb1sA6tZR2G3PHPNwoaQtrZ4LYDC0HWe3vVbS9yUdL2mHpJ9Wj+dJCknvSvpxRLS9uJhx9sln2rRpxfqVV17ZstbuWnl73OHir6xfv75YX7BgQbE+WbUaZz9iAisuGmfx/V13BKCv+LoskARhB5Ig7EAShB1IgrADSXCJKxrzxRdfFOtHHFEeLNq3b1+xfskll7SsPf/888V1D2f8K2kgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLtVW/I7ayzzirWr7vuumL9nHPOaVlrN47eztatW4v1DRs2dPX6kw1HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2SW7u3LnF+tKlS4v1a665plg/8cQTD7mnidq/f3+xPjJS/u/lBw4cqLOdwx5HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2w0C7sexFi8abaHdUu3H02bNnd9JSLYaHh4v1FStWFOvr1q2rs51Jr+2R3fZJtp+zvc32G7ZvqZZPt/2M7beq2+N63y6ATk3kNH6fpL+PiD+X9FeSbrZ9uqRbJT0bEadKerZ6DGBAtQ17RIxExKbq/h5J2yTNknSVpDXV09ZIurpXTQLo3iG9Z7c9W9L3JG2UNCMiRqTRPwi2T2ixzhJJS7prE0C3Jhx229+W9Iikn0TEx/a4c8d9Q0QMSRqqXoOJHYGGTGjozfa3NBr0X0bEo9XiHbZnVvWZknb2pkUAdWh7ZPfoIfx+Sdsi4p4xpXWSbpC0srp9vCcdTgIzZswo1k8//fRi/d577y3WTzvttEPuqS4bN24s1u+8886WtccfL//KcIlqvSZyGn++pL+V9LrtzdWy5RoN+a9tL5b0e0k/7E2LAOrQNuwR8Z+SWr1Bv7jedgD0Cl+XBZIg7EAShB1IgrADSRB2IAkucZ2g6dOnt6ytXr26uO68efOK9Tlz5nTUUx1efPHFYv3uu+8u1p966qli/bPPPjvkntAbHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IIk04+znnntusb5s2bJiff78+S1rs2bN6qinunz66acta6tWrSque8cddxTre/fu7agnDB6O7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQRJpx9oULF3ZV78bWrVuL9SeffLJY37dvX7FeuuZ89+7dxXWRB0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUjCEVF+gn2SpIcknSjpgKShiPhX27dJuknS/1ZPXR4Rv2nzWuWNAehaRIw76/JEwj5T0syI2GT7aEkvS7pa0t9I+iQi7ppoE4Qd6L1WYZ/I/Owjkkaq+3tsb5PU7L9mAXDIDuk9u+3Zkr4naWO1aKnt12w/YPu4FusssT1se7irTgF0pe1p/FdPtL8t6QVJKyLiUdszJH0gKST9k0ZP9W9s8xqcxgM91vF7dkmy/S1JT0p6KiLuGac+W9KTEXFmm9ch7ECPtQp729N425Z0v6RtY4NefXB30EJJW7ptEkDvTOTT+Ask/Yek1zU69CZJyyUtkjRPo6fx70r6cfVhXum1OLIDPdbVaXxdCDvQex2fxgOYHAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9HvK5g8k/c+Yx8dXywbRoPY2qH1J9NapOnv7s1aFvl7P/o2N28MRcXZjDRQMam+D2pdEb53qV2+cxgNJEHYgiabDPtTw9ksGtbdB7Uuit071pbdG37MD6J+mj+wA+oSwA0k0Enbbl9p+0/bbtm9toodWbL9r+3Xbm5uen66aQ2+n7S1jlk23/Yztt6rbcefYa6i322y/V+27zbYvb6i3k2w/Z3ub7Tds31Itb3TfFfrqy37r+3t221Mk/U7SAknbJb0kaVFEbO1rIy3YflfS2RHR+BcwbP+1pE8kPXRwai3b/yxpV0SsrP5QHhcR/zAgvd2mQ5zGu0e9tZpm/O/U4L6rc/rzTjRxZJ8v6e2IeCcivpT0K0lXNdDHwIuIDZJ2fW3xVZLWVPfXaPSXpe9a9DYQImIkIjZV9/dIOjjNeKP7rtBXXzQR9lmS/jDm8XYN1nzvIelp2y/bXtJ0M+OYcXCarer2hIb7+bq203j309emGR+YfdfJ9OfdaiLs401NM0jjf+dHxF9KukzSzdXpKibmZ5JO0egcgCOS7m6ymWqa8Uck/SQiPm6yl7HG6asv+62JsG+XdNKYx9+R9H4DfYwrIt6vbndKekyjbzsGyY6DM+hWtzsb7ucrEbEjIvZHxAFJP1eD+66aZvwRSb+MiEerxY3vu/H66td+ayLsL0k61fZ3bR8p6UeS1jXQxzfYnlp9cCLbUyX9QIM3FfU6STdU92+Q9HiDvfyRQZnGu9U042p43zU+/XlE9P1H0uUa/UT+vyX9YxM9tOhrjqRXq583mu5N0lqNntb9n0bPiBZL+lNJz0p6q7qdPkC9/ZtGp/Z+TaPBmtlQbxdo9K3ha5I2Vz+XN73vCn31Zb/xdVkgCb5BByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ/D+f1mbtgJ8kQQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# coding: utf-8\n", "import sys, os\n", "sys.path.append(os.path.abspath(os.path.join('..', 'sample')))\n", "\n", "import numpy as np\n", "from dataset.mnist import load_mnist # サンプルにあるPythonモジュール\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "def img_show(img):\n", " '''\n", " NumPy用になっているデータをPIL用イメージ画像に変換して、表示します\n", " '''\n", " pil_img = Image.fromarray(np.uint8(img))\n", " # pil_img.show()\n", " plt.imshow(pil_img, cmap='gray')\n", "\n", "# load_mnist で読み込む\n", "# (訓練画像、訓練ラベル), (テスト画像, テストラベル) という形式で MNISTデータを返す\n", "(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)\n", "\n", "# 訓練画像と訓練ラベルを取り出す\n", "img = x_train[0]\n", "label = t_train[0]\n", "print(\"label: \", label)\n", "\n", "# 訓練画像は flatten=True…データを1次元配列にしている\n", "print(\"img.shape: \", img.shape)\n", "\n", "# 形状を元の画像サイズに変形\n", "img = img.reshape(28, 28) \n", "print(\"img.shape: \", img.shape)\n", "\n", "# 画像を表示する\n", "img_show(img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ニューラルネットワークの推論処理\n", "\n", "* 入力層は784\n", " * 画像の大きさ$28 \\times 28 = 784$より\n", "* 出力層は10\n", " * 0 ~ 9 の数字を出すため\n", "\n", "* 隠れ層は2つ、一つは50, もう一つは100\n", " * 任意の値で設定可能" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy:0.9352\n" ] } ], "source": [ "# coding: utf-8\n", "import sys, os\n", "sys.path.append(os.path.abspath(os.path.join('..', 'sample')))\n", "import numpy as np\n", "import pickle\n", "from dataset.mnist import load_mnist\n", "from common.functions import sigmoid, softmax\n", "\n", "\n", "def get_data():\n", " # 正規化されたデータとして前処理を行う\n", " (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)\n", " return x_test, t_test\n", "\n", "\n", "def init_network():\n", " with open(os.path.abspath(os.path.join('..', 'sample', 'ch03', \"sample_weight.pkl\")), 'rb') as f:\n", " network = pickle.load(f)\n", " return network\n", "\n", "\n", "def predict(network, x):\n", " W1, W2, W3 = network['W1'], network['W2'], network['W3']\n", " b1, b2, b3 = network['b1'], network['b2'], network['b3']\n", "\n", " a1 = np.dot(x, W1) + b1\n", " z1 = sigmoid(a1)\n", " a2 = np.dot(z1, W2) + b2\n", " z2 = sigmoid(a2)\n", " a3 = np.dot(z2, W3) + b3\n", " y = softmax(a3)\n", "\n", " return y\n", "\n", "\n", "# MNIST データセットの取得\n", "x, t = get_data()\n", "\n", "# ニューラルネットワークの構築\n", "network = init_network()\n", "\n", "# MNISTデータの画像を分類し\n", "# 確率の高いものを予測結果に入れる\n", "accuracy_cnt = 0\n", "for i in range(len(x)):\n", " y = predict(network, x[i])\n", " p= np.argmax(y) # 最も確率の高い要素のインデックスを取得\n", " if p == t[i]:\n", " accuracy_cnt += 1\n", "\n", "# ニューラルネットワークの認識制度\n", "print(\"Accuracy:\" + str(float(accuracy_cnt) / len(x)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## バッチ処理\n", "\n", "* 入力データを複数にまとめて、一度に計算させる\n", "* 一つ一つ計算するよりは高速化が可能\n", "* まとまった入力データを「バッチ」という\n", "* バッチが多いと逆に遅くなったり、そもそもメモリが足りなくなったりする場合もあるので、必ずバッチ処理をしなければならない、というわけではない" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy:0.9352\n" ] } ], "source": [ "x, t = get_data()\n", "network = init_network()\n", "\n", "batch_size = 100 # バッチの数\n", "accuracy_cnt = 0\n", "\n", "for i in range(0, len(x), batch_size):\n", " x_batch = x[i:i+batch_size]\n", " y_batch = predict(network, x_batch)\n", " p = np.argmax(y_batch, axis=1)\n", " accuracy_cnt += np.sum(p == t[i:i+batch_size])\n", "\n", "print(\"Accuracy:\" + str(float(accuracy_cnt) / len(x)))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }