{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Unsupervised Learning: clustering\n", "\n", "Notebook adapted from: https://github.com/ageron/handson-ml/blob/master/08_dimensionality_reduction.ipynb\n", "\n", "See Machine Learning avec Scikit-Learn, Mise en oeuvre et cas concrets, by Aurélien Géron" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "First, let's import a few common modules, make Matplotlib render interactive figures in the notebook, and prepare a function to save the figures:" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Common imports\n", "import numpy as np\n", "import os\n", "\n", "# to make this notebook's output stable across runs\n", "np.random.seed(42)\n", "\n", "# To plot pretty figures\n", "%matplotlib notebook\n", "#inline\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "from mpl_toolkits.mplot3d import Axes3D\n", "plt.rcParams['axes.labelsize'] = 14\n", "plt.rcParams['xtick.labelsize'] = 12\n", "plt.rcParams['ytick.labelsize'] = 12\n", "#plt.rcParams['savefig.dpi'] = 300\n", "#plt.rcParams['figure.dpi'] = 150\n", "\n", "def save_fig(fig_id, tight_layout=True):\n", "    path = os.path.join(fig_id + \".png\")\n", "    print(\"Saving figure\", fig_id)\n", "    if tight_layout:\n", "        plt.tight_layout()\n", "    plt.savefig(path, format='png', dpi=300)\n", "\n", "\n", "# Ignore useless warnings (see SciPy issue #5998)\n", "import warnings\n", "warnings.filterwarnings(action=\"ignore\", module=\"scipy\", message=\"^internal gelsd\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Projection methods\n", "Build 3D dataset:" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from mpl_toolkits.mplot3d import Axes3D\n", "\n", "fig = plt.figure()\n", "ax = fig.add_subplot(111, projection='3d')\n", "\n", "X3D_above = X[X[:, 2] > X3D_inv[:, 2]]\n", "X3D_below = X[X[:, 2] <= X3D_inv[:, 2]]\n", "\n", "ax.plot(X3D_below[:, 0], X3D_below[:, 1], X3D_below[:, 2], \"bo\", alpha=0.5)\n", "\n", "ax.plot_surface(x1, x2, z, alpha=0.2, color=\"k\")\n", "np.linalg.norm(C, axis=0)\n", "ax.add_artist(Arrow3D([0, C[0, 0]],[0, C[0, 1]],[0, C[0, 2]], mutation_scale=15, lw=1, arrowstyle=\"-|>\", color=\"k\"))\n", "ax.add_artist(Arrow3D([0, C[1, 0]],[0, C[1, 1]],[0, C[1, 2]], mutation_scale=15, lw=1, arrowstyle=\"-|>\", color=\"k\"))\n", "ax.plot([0], [0], [0], \"k.\")\n", "\n", "for i in range(m):\n", "    if X[i, 2] > X3D_inv[i, 2]:\n", "        ax.plot([X[i][0], X3D_inv[i][0]], [X[i][1], X3D_inv[i][1]], [X[i][2], X3D_inv[i][2]], \"k-\")\n", "    else:\n", "        ax.plot([X[i][0], X3D_inv[i][0]], [X[i][1], X3D_inv[i][1]], [X[i][2], X3D_inv[i][2]], \"-\", color=\"#505050\")\n", "\n", "ax.plot(X3D_inv[:, 0], X3D_inv[:, 1], X3D_inv[:, 2], \"k+\")\n", "ax.plot(X3D_inv[:, 0], X3D_inv[:, 1], X3D_inv[:, 2], \"k.\")\n", "ax.plot(X3D_above[:, 0], X3D_above[:, 
1], X3D_above[:, 2], \"bo\")\n", "ax.set_xlabel(\"$x_1$\", fontsize=18)\n", "ax.set_ylabel(\"$x_2$\", fontsize=18)\n", "ax.set_zlabel(\"$x_3$\", fontsize=18)\n", "ax.set_xlim(axes[0:2])\n", "ax.set_ylim(axes[2:4])\n", "ax.set_zlim(axes[4:6])\n", "\n", "#save_fig(\"dataset_3d_plot\")\n", "plt.show()" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving figure classification_vs_clustering_diagram\n" ] } ], "source": [ "plt.figure(figsize=(9, 3.5))\n", "\n", "plt.subplot(121)\n", "plt.plot(X[y==0, 0], X[y==0, 1], \"yo\", label=\"Iris-Setosa\")\n", "plt.plot(X[y==1, 0], X[y==1, 1], \"bs\", label=\"Iris-Versicolor\")\n", "plt.plot(X[y==2, 0], X[y==2, 1], \"g^\", label=\"Iris-Virginica\")\n", "plt.xlabel(\"Petal length\", fontsize=14)\n", "plt.ylabel(\"Petal width\", fontsize=14)\n", "plt.legend(fontsize=12)\n", "\n", "plt.subplot(122)\n", "plt.scatter(X[:, 0], X[:, 1], c=\"k\", marker=\".\")\n", "plt.xlabel(\"Petal length\", fontsize=14)\n", "plt.tick_params(labelleft=False)\n", "\n", "save_fig(\"classification_vs_clustering_diagram\")\n", "plt.show()" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "A Gaussian mixture model (explained below) can actually separate these clusters pretty well." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "from sklearn.mixture import GaussianMixture" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "y_pred = GaussianMixture(n_components=3, random_state=42).fit(X).predict(X)\n", "mapping = np.array([2, 0, 1])\n", "y_pred = np.array([mapping[cluster_id] for cluster_id in y_pred])" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "plt.plot(X[y_pred==0, 0], X[y_pred==0, 1], \"yo\", label=\"Cluster 1\")\n", "plt.plot(X[y_pred==1, 0], X[y_pred==1, 1], \"bs\", label=\"Cluster 2\")\n", "plt.plot(X[y_pred==2, 0], X[y_pred==2, 1], \"g^\", label=\"Cluster 3\")\n", "plt.xlabel(\"Petal length\", fontsize=14)\n", "plt.ylabel(\"Petal width\", fontsize=14)\n", "plt.legend(loc=\"upper right\", fontsize=12)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "145" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.sum(y_pred==y)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9666666666666667" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.sum(y_pred==y) / len(y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## K-Means" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by generating some blobs:" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import make_blobs" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "blob_centers = np.array(\n", " [[ 0.2, 2.3],\n", " [-1.5 , 
2.3],\n", " [-2.8, 1.8],\n", " [-2.8, 2.8],\n", " [-2.8, 1.3]])\n", "blob_std = np.array([0.4, 0.3, 0.1, 0.1, 0.1])" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "X, y = make_blobs(n_samples=2000, centers=blob_centers,\n", " cluster_std=blob_std, random_state=7)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's plot them:" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "def plot_clusters(X, y=None):\n", " plt.scatter(X[:, 0], X[:, 1], c=y, s=1)\n", " plt.xlabel(\"$x_1$\", fontsize=14)\n", " plt.ylabel(\"$x_2$\", fontsize=14, rotation=0)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. 
\" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " event.shiftKey = false;\n", " // Send a \"J\" for go to next cell\n", " event.which = 74;\n", " event.keyCode = 74;\n", " manager.command_mode();\n", " manager.handle_keydown(event);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an 
active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Saving figure voronoi_diagram\n" ] } ], "source": [ "plt.figure(figsize=(8, 4))\n", "plot_decision_boundaries(kmeans, X)\n", "save_fig(\"voronoi_diagram\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Not bad! Some of the instances near the edges were probably assigned to the wrong cluster, but overall it looks pretty good." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Hard Clustering _vs_ Soft Clustering" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Rather than arbitrarily choosing the closest cluster for each instance, which is called _hard clustering_, it might be better to measure the distance of each instance to all 5 centroids. 
This is what the `transform()` method does:" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[2.81093633, 0.32995317, 2.9042344 , 1.49439034, 2.88633901],\n", " [5.80730058, 2.80290755, 5.84739223, 4.4759332 , 5.84236351],\n", " [1.21475352, 3.29399768, 0.29040966, 1.69136631, 1.71086031],\n", " [0.72581411, 3.21806371, 0.36159148, 1.54808703, 1.21567622]])" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "kmeans.transform(X_new)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can verify that this is indeed the Euclidean distance between each instance and each centroid:" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[2.81093633, 0.32995317, 2.9042344 , 1.49439034, 2.88633901],\n", " [5.80730058, 2.80290755, 5.84739223, 4.4759332 , 5.84236351],\n", " [1.21475352, 3.29399768, 0.29040966, 1.69136631, 1.71086031],\n", " [0.72581411, 3.21806371, 0.36159148, 1.54808703, 1.21567622]])" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.linalg.norm(np.tile(X_new, (1, k)).reshape(-1, k, 2) - kmeans.cluster_centers_, axis=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### K-Means Algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The K-Means algorithm is one of the fastest clustering algorithms, and also one of the simplest:\n", "* First initialize $k$ centroids randomly: $k$ distinct instances are chosen randomly from the dataset and the centroids are placed at their locations.\n", "* Repeat until convergence (i.e., until the centroids stop moving):\n", " * Assign each instance to the closest centroid.\n", " * Update the centroids to be the mean of the instances that are assigned to them." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `KMeans` class applies an optimized algorithm by default. To get the original K-Means algorithm (for educational purposes only), you must set `init=\"random\"`, `n_init=1` and `algorithm=\"full\"`. These hyperparameters will be explained below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's run the K-Means algorithm for 1, 2 and 3 iterations, to see how the centroids move around:" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "KMeans(algorithm='full', copy_x=True, init='random', max_iter=3, n_clusters=5,\n", " n_init=1, n_jobs=None, precompute_distances='auto', random_state=1,\n", " tol=0.0001, verbose=0)" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "kmeans_iter1 = KMeans(n_clusters=5, init=\"random\", n_init=1,\n", " algorithm=\"full\", max_iter=1, random_state=1)\n", "kmeans_iter2 = KMeans(n_clusters=5, init=\"random\", n_init=1,\n", " algorithm=\"full\", max_iter=2, random_state=1)\n", "kmeans_iter3 = KMeans(n_clusters=5, init=\"random\", n_init=1,\n", " algorithm=\"full\", max_iter=3, random_state=1)\n", "kmeans_iter1.fit(X)\n", "kmeans_iter2.fit(X)\n", "kmeans_iter3.fit(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And let's plot this:" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/home/aurelien/anaconda3/lib/python3.6/site-packages/matplotlib/cbook/__init__.py:424: MatplotlibDeprecationWarning: \n", "Passing one of 'on', 'true', 'off', 'false' as a boolean is deprecated; use an actual boolean (True/False) instead.\n", " warn_deprecated(\"2.2\", \"Passing one of 'on', 'true', 'off', 'false' as a \"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saving figure kmeans_variability_diagram\n" ] } ], "source": [ "kmeans_rnd_init1 = KMeans(n_clusters=5, init=\"random\", n_init=1,\n", " algorithm=\"full\", random_state=11)\n", "kmeans_rnd_init2 = KMeans(n_clusters=5, init=\"random\", n_init=1,\n", " algorithm=\"full\", random_state=19)\n", "\n", "plot_clusterer_comparison(kmeans_rnd_init1, kmeans_rnd_init2, X,\n", " \"Solution 1\", \"Solution 2 (with a different random init)\")\n", "\n", "save_fig(\"kmeans_variability_diagram\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Inertia" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To select the best model, we will need a way to evaluate a K-Means model's performance. 
Unfortunately, clustering is an unsupervised task, so we do not have the targets. But at least we can measure the distance between each instance and its centroid. This is the idea behind the _inertia_ metric:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kmeans.inertia_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can easily verify, inertia is the sum of the squared distances between each training instance and its closest centroid:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_dist = kmeans.transform(X)\n", "np.sum(X_dist[np.arange(len(X_dist)), kmeans.labels_]**2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `score()` method returns the negative inertia. Why negative? Because a predictor's `score()` method must always respect the \"_greater is better_\" rule." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kmeans.score(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multiple Initializations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So one approach to solve the variability issue is to simply run the K-Means algorithm multiple times with different random initializations, and select the solution that minimizes the inertia. For example, here are the inertias of the two \"bad\" models shown in the previous figure:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kmeans_rnd_init1.inertia_" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kmeans_rnd_init2.inertia_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, they have a higher inertia than the first \"good\" model we trained, which means they are probably worse." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When you set the `n_init` hyperparameter, Scikit-Learn runs the original algorithm `n_init` times, and selects the solution that minimizes the inertia. By default, Scikit-Learn sets `n_init=10`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kmeans_rnd_10_inits = KMeans(n_clusters=5, init=\"random\", n_init=10,\n", " algorithm=\"full\", random_state=11)\n", "kmeans_rnd_10_inits.fit(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, we end up with the initial model, which is most likely the optimal K-Means solution (at least in terms of inertia, and assuming $k=5$)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8, 4))\n", "plot_decision_boundaries(kmeans_rnd_10_inits, X)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### K-Means++" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instead of initializing the centroids entirely randomly, it is preferable to initialize them using the following algorithm, proposed in a [2006 paper](https://goo.gl/eNUPw6) by David Arthur and Sergei Vassilvitskii:\n", "* Take one centroid $c_1$, chosen uniformly at random from the dataset.\n", "* Take a new center $c_i$, choosing an instance $\\mathbf{x}_i$ with probability: $D(\\mathbf{x}_i)^2$ / $\\sum\\limits_{j=1}^{m}{D(\\mathbf{x}_j)}^2$ where $D(\\mathbf{x}_i)$ is the distance between the instance $\\mathbf{x}_i$ and the closest centroid that was already chosen. This probability distribution ensures that instances that are farther away from already chosen centroids are much more likely to be selected as centroids.\n", "* Repeat the previous step until all $k$ centroids have been chosen." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The rest of the K-Means++ algorithm is just regular K-Means. 
With this initialization, the K-Means algorithm is much less likely to converge to a suboptimal solution, so it is possible to reduce `n_init` considerably. Most of the time, this largely compensates for the additional complexity of the initialization process." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To set the initialization to K-Means++, simply set `init=\"k-means++\"` (this is actually the default):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "KMeans()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "good_init = np.array([[-3, 3], [-3, 2], [-3, 1], [-1, 2], [0, 2]])\n", "kmeans = KMeans(n_clusters=5, init=good_init, n_init=1, random_state=42)\n", "kmeans.fit(X)\n", "kmeans.inertia_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Accelerated K-Means" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The K-Means algorithm can be significantly accelerated by avoiding many unnecessary distance calculations: this is achieved by exploiting the triangle inequality (given three points A, B and C, the distance AC is always such that AC ≤ AB + BC) and by keeping track of lower and upper bounds for distances between instances and centroids (see this [2003 paper](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf) by Charles Elkan for more details)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use Elkan's variant of K-Means, just set `algorithm=\"elkan\"`. Note that it does not support sparse data, so by default, Scikit-Learn uses `\"elkan\"` for dense data, and `\"full\"` (the regular K-Means algorithm) for sparse data. 
(*should take over 1 minute*)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%timeit -n 50 KMeans(algorithm=\"elkan\").fit(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%timeit -n 50 KMeans(algorithm=\"full\").fit(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Finding the optimal number of clusters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What if the number of clusters was set to a lower or greater value than 5?" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kmeans_k3 = KMeans(n_clusters=3, random_state=42)\n", "kmeans_k8 = KMeans(n_clusters=8, random_state=42)\n", "\n", "plot_clusterer_comparison(kmeans_k3, kmeans_k8, X, \"$k=3$\", \"$k=8$\")\n", "save_fig(\"bad_n_clusters_diagram\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ouch, these two models don't look great. What about their inertias?" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kmeans_k3.inertia_" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kmeans_k8.inertia_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "No, we cannot simply take the value of $k$ that minimizes the inertia, since it keeps getting lower as we increase $k$. Indeed, the more clusters there are, the closer each instance will be to its closest centroid, and therefore the lower the inertia will be. 
However, we can plot the inertia as a function of $k$ and analyze the resulting curve:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kmeans_per_k = [KMeans(n_clusters=k, random_state=42).fit(X)\n", " for k in range(1, 10)]\n", "inertias = [model.inertia_ for model in kmeans_per_k]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8, 3.5))\n", "plt.plot(range(1, 10), inertias, \"bo-\")\n", "plt.xlabel(\"$k$\", fontsize=14)\n", "plt.ylabel(\"Inertia\", fontsize=14)\n", "plt.annotate('Elbow',\n", " xy=(4, inertias[3]),\n", " xytext=(0.55, 0.55),\n", " textcoords='figure fraction',\n", " fontsize=16,\n", " arrowprops=dict(facecolor='black', shrink=0.1)\n", " )\n", "plt.axis([1, 8.5, 0, 1300])\n", "save_fig(\"inertia_vs_k_diagram\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, there is an elbow at $k=4$, which means that fewer clusters than that would be bad, and more clusters would not help much and might cut clusters in half. So $k=4$ is a pretty good choice. Of course, in this example it is not perfect, since it means that the two blobs in the lower left will be considered as just a single cluster, but it's a pretty good clustering nonetheless." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_decision_boundaries(kmeans_per_k[4-1], X)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another approach is to look at the _silhouette score_, which is the mean _silhouette coefficient_ over all the instances. 
An instance's silhouette coefficient is equal to $(b - a)/\\max(a, b)$ where $a$ is the mean distance to the other instances in the same cluster (it is the _mean intra-cluster distance_), and $b$ is the _mean nearest-cluster distance_, that is, the mean distance to the instances of the next closest cluster (defined as the one that minimizes $b$, excluding the instance's own cluster). The silhouette coefficient can vary between -1 and +1: a coefficient close to +1 means that the instance is well inside its own cluster and far from other clusters, a coefficient close to 0 means that it is close to a cluster boundary, and a coefficient close to -1 means that the instance may have been assigned to the wrong cluster." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot the silhouette score as a function of $k$:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import silhouette_score" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "silhouette_score(X, kmeans.labels_)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "silhouette_scores = [silhouette_score(X, model.labels_)\n", "                     for model in kmeans_per_k[1:]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8, 3))\n", "plt.plot(range(2, 10), silhouette_scores, \"bo-\")\n", "plt.xlabel(\"$k$\", fontsize=14)\n", "plt.ylabel(\"Silhouette score\", fontsize=14)\n", "plt.axis([1.8, 8.5, 0.55, 0.7])\n", "save_fig(\"silhouette_score_vs_k_diagram\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, this visualization is much richer than the previous one: in particular, although it confirms that $k=4$ is a very good choice, it also underlines the fact that $k=5$ is quite good as well."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An even more informative visualization is given when you plot every instance's silhouette coefficient, sorted by the cluster they are assigned to and by the value of the coefficient. This is called a _silhouette diagram_:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import silhouette_samples\n", "from matplotlib.ticker import FixedLocator, FixedFormatter\n", "\n", "plt.figure(figsize=(11, 9))\n", "\n", "for k in (3, 4, 5, 6):\n", "    plt.subplot(2, 2, k - 2)\n", "    \n", "    y_pred = kmeans_per_k[k - 1].labels_\n", "    silhouette_coefficients = silhouette_samples(X, y_pred)\n", "\n", "    padding = len(X) // 30\n", "    pos = padding\n", "    ticks = []\n", "    for i in range(k):\n", "        coeffs = silhouette_coefficients[y_pred == i]\n", "        coeffs.sort()\n", "\n", "        # cm.spectral was renamed nipy_spectral in Matplotlib 2.0 and later removed\n", "        color = matplotlib.cm.nipy_spectral(i / k)\n", "        plt.fill_betweenx(np.arange(pos, pos + len(coeffs)), 0, coeffs,\n", "                          facecolor=color, edgecolor=color, alpha=0.7)\n", "        ticks.append(pos + len(coeffs) // 2)\n", "        pos += len(coeffs) + padding\n", "\n", "    plt.gca().yaxis.set_major_locator(FixedLocator(ticks))\n", "    plt.gca().yaxis.set_major_formatter(FixedFormatter(range(k)))\n", "    if k in (3, 5):\n", "        plt.ylabel(\"Cluster\")\n", "    \n", "    if k in (5, 6):\n", "        plt.gca().set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])\n", "        plt.xlabel(\"Silhouette Coefficient\")\n", "    else:\n", "        # tick_params expects a boolean; the string 'off' is no longer accepted\n", "        plt.tick_params(labelbottom=False)\n", "\n", "    # dashed red line: mean silhouette score for this value of k\n", "    plt.axvline(x=silhouette_scores[k - 2], color=\"red\", linestyle=\"--\")\n", "    plt.title(\"$k={}$\".format(k), fontsize=16)\n", "\n", "save_fig(\"silhouette_analysis_diagram\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Limits of K-Means" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X1, y1 = make_blobs(n_samples=1000, centers=((4, -4), (0, 0)), random_state=42)\n", "X1 = 
X1.dot(np.array([[0.374, 0.95], [0.732, 0.598]]))\n", "X2, y2 = make_blobs(n_samples=250, centers=1, random_state=42)\n", "X2 = X2 + [6, -8]\n", "X = np.r_[X1, X2]\n", "y = np.r_[y1, y2]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_clusters(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "kmeans_good = KMeans(n_clusters=3, init=np.array([[-1.5, 2.5], [0.5, 0], [4, 0]]), n_init=1, random_state=42)\n", "kmeans_bad = KMeans(n_clusters=3, random_state=42)\n", "kmeans_good.fit(X)\n", "kmeans_bad.fit(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "plt.figure(figsize=(10, 3.2))\n", "\n", "plt.subplot(121)\n", "plot_decision_boundaries(kmeans_good, X)\n", "plt.title(\"Inertia = {:.1f}\".format(kmeans_good.inertia_), fontsize=14)\n", "\n", "plt.subplot(122)\n", "plot_decision_boundaries(kmeans_bad, X, show_ylabels=False)\n", "plt.title(\"Inertia = {:.1f}\".format(kmeans_bad.inertia_), fontsize=14)\n", "\n", "save_fig(\"bad_kmeans_diagram\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using clustering for image segmentation\n", "\n", "\n", "![Ladybug](ladybug.png)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from matplotlib.image import imread\n", "image = imread(os.path.join(\"ladybug.png\"))\n", "image.shape\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X = image.reshape(-1, 3)\n", "kmeans = KMeans(n_clusters=8, random_state=42).fit(X)\n", "segmented_img = kmeans.cluster_centers_[kmeans.labels_]\n", "segmented_img = segmented_img.reshape(image.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "segmented_imgs = []\n", "n_colors = (10, 8, 6, 4, 2)\n", "for n_clusters in n_colors:\n", " kmeans = 
KMeans(n_clusters=n_clusters, random_state=42).fit(X)\n", " segmented_img = kmeans.cluster_centers_[kmeans.labels_]\n", " segmented_imgs.append(segmented_img.reshape(image.shape))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "plt.figure(figsize=(10,5))\n", "plt.subplots_adjust(wspace=0.05, hspace=0.1)\n", "\n", "plt.subplot(231)\n", "plt.imshow(image)\n", "plt.title(\"Original image\")\n", "plt.axis('off')\n", "\n", "for idx, n_clusters in enumerate(n_colors):\n", " plt.subplot(232 + idx)\n", " plt.imshow(segmented_imgs[idx])\n", " plt.title(\"{} colors\".format(n_clusters))\n", " plt.axis('off')\n", "\n", "save_fig('image_segmentation_diagram', tight_layout=False)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DBSCAN" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import make_moons" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X, y = make_moons(n_samples=1000, noise=0.05, random_state=42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.cluster import DBSCAN" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dbscan = DBSCAN(eps=0.05, min_samples=5)\n", "dbscan.fit(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dbscan.labels_[:10]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(dbscan.core_sample_indices_)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dbscan.core_sample_indices_[:10]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dbscan.components_[:3]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "np.unique(dbscan.labels_)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dbscan2 = DBSCAN(eps=0.2)\n", "dbscan2.fit(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_dbscan(dbscan, X, size, show_xlabels=True, show_ylabels=True):\n", " core_mask = np.zeros_like(dbscan.labels_, dtype=bool)\n", " core_mask[dbscan.core_sample_indices_] = True\n", " anomalies_mask = dbscan.labels_ == -1\n", " non_core_mask = ~(core_mask | anomalies_mask)\n", "\n", " cores = dbscan.components_\n", " anomalies = X[anomalies_mask]\n", " non_cores = X[non_core_mask]\n", " \n", " plt.scatter(cores[:, 0], cores[:, 1],\n", " c=dbscan.labels_[core_mask], marker='o', s=size, cmap=\"Paired\")\n", " plt.scatter(cores[:, 0], cores[:, 1], marker='*', s=20, c=dbscan.labels_[core_mask])\n", " plt.scatter(anomalies[:, 0], anomalies[:, 1],\n", " c=\"r\", marker=\"x\", s=100)\n", " plt.scatter(non_cores[:, 0], non_cores[:, 1], c=dbscan.labels_[non_core_mask], marker=\".\")\n", " if show_xlabels:\n", " plt.xlabel(\"$x_1$\", fontsize=14)\n", " else:\n", " plt.tick_params(labelbottom='off')\n", " if show_ylabels:\n", " plt.ylabel(\"$x_2$\", fontsize=14, rotation=0)\n", " else:\n", " plt.tick_params(labelleft='off')\n", " plt.title(\"eps={:.2f}, min_samples={}\".format(dbscan.eps, dbscan.min_samples), fontsize=14)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(9, 3.2))\n", "\n", "plt.subplot(121)\n", "plot_dbscan(dbscan, X, size=100)\n", "\n", "plt.subplot(122)\n", "plot_dbscan(dbscan2, X, size=600, show_ylabels=False)\n", "\n", "save_fig(\"dbscan_diagram\")\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dbscan = dbscan2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from 
sklearn.neighbors import KNeighborsClassifier" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "knn = KNeighborsClassifier(n_neighbors=50)\n", "knn.fit(dbscan.components_, dbscan.labels_[dbscan.core_sample_indices_])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_new = np.array([[-0.5, 0], [0, 0.5], [1, -0.1], [2, 1]])\n", "knn.predict(X_new)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "knn.predict_proba(X_new)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(6, 3))\n", "plot_decision_boundaries(knn, X, show_centroids=False)\n", "plt.scatter(X_new[:, 0], X_new[:, 1], c=\"b\", marker=\"+\", s=200, zorder=10)\n", "save_fig(\"cluster_classification_diagram\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_dist, y_pred_idx = knn.kneighbors(X_new, n_neighbors=1)\n", "y_pred = dbscan.labels_[dbscan.core_sample_indices_][y_pred_idx]\n", "y_pred[y_dist > 0.2] = -1\n", "y_pred.ravel()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Spectral Clustering" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.cluster import SpectralClustering" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sc1 = SpectralClustering(n_clusters=2, gamma=100, random_state=42)\n", "sc1.fit(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sc2 = SpectralClustering(n_clusters=2, gamma=1, random_state=42)\n", "sc2.fit(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "np.percentile(sc1.affinity_matrix_, 95)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], 
"source": [ "def plot_spectral_clustering(sc, X, size, alpha, show_xlabels=True, show_ylabels=True):\n", " plt.scatter(X[:, 0], X[:, 1], marker='o', s=size, c='gray', cmap=\"Paired\", alpha=alpha)\n", " plt.scatter(X[:, 0], X[:, 1], marker='o', s=30, c='w')\n", " plt.scatter(X[:, 0], X[:, 1], marker='.', s=10, c=sc.labels_, cmap=\"Paired\")\n", " \n", " if show_xlabels:\n", " plt.xlabel(\"$x_1$\", fontsize=14)\n", " else:\n", " plt.tick_params(labelbottom='off')\n", " if show_ylabels:\n", " plt.ylabel(\"$x_2$\", fontsize=14, rotation=0)\n", " else:\n", " plt.tick_params(labelleft='off')\n", " plt.title(\"RBF gamma={}\".format(sc.gamma), fontsize=14)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(9, 3.2))\n", "\n", "plt.subplot(121)\n", "plot_spectral_clustering(sc1, X, size=500, alpha=0.1)\n", "\n", "plt.subplot(122)\n", "plot_spectral_clustering(sc2, X, size=4000, alpha=0.01, show_ylabels=False)\n", "\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Gaussian Mixtures" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X1, y1 = make_blobs(n_samples=1000, centers=((4, -4), (0, 0)), random_state=42)\n", "X1 = X1.dot(np.array([[0.374, 0.95], [0.732, 0.598]]))\n", "X2, y2 = make_blobs(n_samples=250, centers=1, random_state=42)\n", "X2 = X2 + [6, -8]\n", "X = np.r_[X1, X2]\n", "y = np.r_[y1, y2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's train a Gaussian mixture model on the previous dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.mixture import GaussianMixture" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm = GaussianMixture(n_components=3, n_init=10, random_state=42)\n", "gm.fit(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at the parameters 
that the EM algorithm estimated:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm.weights_" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm.means_" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm.covariances_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Did the algorithm actually converge?" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm.converged_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Yes, good. How many iterations did it take?" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm.n_iter_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can now use the model to predict which cluster each instance belongs to (hard clustering) or the probabilities that it came from each cluster (soft clustering). For this, just use the `predict()` method or the `predict_proba()` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm.predict(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm.predict_proba(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a generative model, so you can sample new instances from it (and get their labels):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_new, y_new = gm.sample(6)\n", "X_new" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_new" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that they are sampled sequentially from each cluster."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also estimate the log of the _probability density function_ (PDF) at any location using the `score_samples()` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm.score_samples(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check that the PDF integrates to 1 over the whole space. We take a large square around the clusters and chop it into a grid of tiny squares, then compute the approximate probability that an instance will be generated in each tiny square (by multiplying the PDF at one corner of the tiny square by the area of the square), and finally sum all these probabilities. The result is very close to 1:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "resolution = 100\n", "grid = np.arange(-10, 10, 1 / resolution)\n", "xx, yy = np.meshgrid(grid, grid)\n", "X_full = np.vstack([xx.ravel(), yy.ravel()]).T\n", "\n", "pdf = np.exp(gm.score_samples(X_full))\n", "pdf_probas = pdf * (1 / resolution) ** 2\n", "pdf_probas.sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's plot the resulting decision boundaries (dashed lines) and density contours:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from matplotlib.colors import LogNorm\n", "\n", "def plot_gaussian_mixture(clusterer, X, resolution=1000, show_ylabels=True):\n", " mins = X.min(axis=0) - 0.1\n", " maxs = X.max(axis=0) + 0.1\n", " xx, yy = np.meshgrid(np.linspace(mins[0], maxs[0], resolution),\n", " np.linspace(mins[1], maxs[1], resolution))\n", " Z = -clusterer.score_samples(np.c_[xx.ravel(), yy.ravel()])\n", " Z = Z.reshape(xx.shape)\n", "\n", " plt.contourf(xx, yy, Z,\n", " norm=LogNorm(vmin=1.0, vmax=30.0),\n", " levels=np.logspace(0, 2, 12))\n", " plt.contour(xx, yy, Z,\n", " norm=LogNorm(vmin=1.0, vmax=30.0),\n", " 
levels=np.logspace(0, 2, 12),\n", " linewidths=1, colors='k')\n", "\n", " Z = clusterer.predict(np.c_[xx.ravel(), yy.ravel()])\n", " Z = Z.reshape(xx.shape)\n", " plt.contour(xx, yy, Z,\n", " linewidths=2, colors='r', linestyles='dashed')\n", " \n", " plt.plot(X[:, 0], X[:, 1], 'k.', markersize=2)\n", " plot_centroids(clusterer.means_, clusterer.weights_)\n", "\n", " plt.xlabel(\"$x_1$\", fontsize=14)\n", " if show_ylabels:\n", "     plt.ylabel(\"$x_2$\", fontsize=14, rotation=0)\n", " else:\n", "     plt.tick_params(labelleft=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8, 4))\n", "\n", "plot_gaussian_mixture(gm, X)\n", "\n", "save_fig(\"gaussian_mixtures_diagram\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can impose constraints on the covariance matrices that the algorithm looks for by setting the `covariance_type` hyperparameter:\n", "* `\"full\"` (default): no constraint, all clusters can take on any ellipsoidal shape of any size.\n", "* `\"tied\"`: all clusters must have the same shape, which can be any ellipsoid (i.e., they all share the same covariance matrix).\n", "* `\"spherical\"`: all clusters must be spherical, but they can have different diameters (i.e., different variances).\n", "* `\"diag\"`: clusters can take on any ellipsoidal shape of any size, but the ellipsoid's axes must be parallel to the coordinate axes (i.e., the covariance matrices must be diagonal)."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm_full = GaussianMixture(n_components=3, n_init=10, covariance_type=\"full\", random_state=42)\n", "gm_tied = GaussianMixture(n_components=3, n_init=10, covariance_type=\"tied\", random_state=42)\n", "gm_spherical = GaussianMixture(n_components=3, n_init=10, covariance_type=\"spherical\", random_state=42)\n", "gm_diag = GaussianMixture(n_components=3, n_init=10, covariance_type=\"diag\", random_state=42)\n", "gm_full.fit(X)\n", "gm_tied.fit(X)\n", "gm_spherical.fit(X)\n", "gm_diag.fit(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def compare_gaussian_mixtures(gm1, gm2, X):\n", " plt.figure(figsize=(9, 4))\n", "\n", " plt.subplot(121)\n", " plot_gaussian_mixture(gm1, X)\n", " plt.title('covariance_type=\"{}\"'.format(gm1.covariance_type), fontsize=14)\n", "\n", " plt.subplot(122)\n", " plot_gaussian_mixture(gm2, X, show_ylabels=False)\n", " plt.title('covariance_type=\"{}\"'.format(gm2.covariance_type), fontsize=14)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "compare_gaussian_mixtures(gm_tied, gm_spherical, X)\n", "\n", "save_fig(\"covariance_type_diagram\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "compare_gaussian_mixtures(gm_full, gm_diag, X)\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Anomaly Detection using Gaussian Mixtures" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Gaussian Mixtures can be used for _anomaly detection_: instances located in low-density regions can be considered anomalies. You must define what density threshold you want to use. For example, in a manufacturing company that tries to detect defective products, the ratio of defective products is usually well-known. 
Say it is equal to 4%, then you can set the density threshold to be the value that results in having 4% of the instances located in areas below that threshold density:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "densities = gm.score_samples(X)\n", "density_threshold = np.percentile(densities, 4)\n", "anomalies = X[densities < density_threshold]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure()#figsize=(8, 4))\n", "\n", "plot_gaussian_mixture(gm, X)\n", "plt.scatter(anomalies[:, 0], anomalies[:, 1], color='r', marker='*')\n", "#plt.ylim(ymax=5.1)\n", "\n", "save_fig(\"mixture_anomaly_detection_diagram\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model selection" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We cannot use the inertia or the silhouette score because they both assume that the clusters are spherical. Instead, we can try to find the model that minimizes a theoretical information criterion such as the Bayesian Information Criterion (BIC) or the Akaike Information Criterion (AIC):\n", "\n", "${BIC} = {\\log(m)p - 2\\log({\\hat L})}$\n", "\n", "${AIC} = 2p - 2\\log(\\hat L)$\n", "\n", "* $m$ is the number of instances.\n", "* $p$ is the number of parameters learned by the model.\n", "* $\\hat L$ is the maximized value of the likelihood function of the model. This is the conditional probability of the observed data $\\mathbf{X}$, given the model and its optimized parameters.\n", "\n", "Both BIC and AIC penalize models that have more parameters to learn (e.g., more clusters), and reward models that fit the data well (i.e., models that give a high likelihood to the observed data)." 
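, "\n", "\n", "As a rough illustration of how the two criteria differ (a sketch based only on the definitions above): subtracting one from the other leaves just the penalty terms,\n", "\n", "${BIC} - {AIC} = (\\log(m) - 2)\\,p$\n", "\n", "Assuming $m = 1250$ instances (the size of the dataset built above, 1,000 + 250) and the natural logarithm, $\\log(m) \\approx 7.13$, so BIC charges about $7.13$ per parameter where AIC charges $2$. With more than $e^2 \\approx 7.4$ instances, BIC therefore penalizes extra parameters more heavily than AIC and tends to select simpler models."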
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm.bic(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gm.aic(X)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We could compute the BIC and the AIC manually like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_clusters = 3\n", "n_dims = 2\n", "n_params_for_weights = n_clusters - 1\n", "n_params_for_means = n_clusters * n_dims\n", "n_params_for_covariance = n_clusters * n_dims * (n_dims + 1) // 2\n", "n_params = n_params_for_weights + n_params_for_means + n_params_for_covariance\n", "max_log_likelihood = gm.score(X) * len(X) # log(L^)\n", "bic = np.log(len(X)) * n_params - 2 * max_log_likelihood\n", "aic = 2 * n_params - 2 * max_log_likelihood" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bic, aic" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There's one weight per cluster, but the weights must sum to 1, so we have one degree of freedom less, hence the -1. Similarly, because a covariance matrix is symmetric, the number of degrees of freedom of an $n \\times n$ covariance matrix is not $n^2$, but $1 + 2 + \\dots + n = \\dfrac{n (n+1)}{2}$."
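, "\n", "\n", "As a worked check of the count (using the values from the code above, $k = 3$ clusters and $n = 2$ dimensions, with full covariance matrices assumed):\n", "\n", "$p = (k - 1) + k\\,n + k\\,\\dfrac{n(n+1)}{2} = 2 + 6 + 9 = 17$"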
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's train Gaussian Mixture models with various values of $k$ and measure their BIC:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gms_per_k = [GaussianMixture(n_components=k, n_init=10, random_state=42).fit(X)\n", "             for k in range(1, 11)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bics = [model.bic(X) for model in gms_per_k]\n", "aics = [model.aic(X) for model in gms_per_k]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8, 3))\n", "plt.plot(range(1, 11), bics, \"bo-\", label=\"BIC\")\n", "plt.plot(range(1, 11), aics, \"go--\", label=\"AIC\")\n", "plt.xlabel(\"$k$\", fontsize=14)\n", "plt.ylabel(\"Information Criterion\", fontsize=14)\n", "plt.axis([1, 9.5, np.min(aics) - 50, np.max(aics) + 50])\n", "plt.annotate('Minimum',\n", "             xy=(3, bics[2]),\n", "             xytext=(0.35, 0.6),\n", "             textcoords='figure fraction',\n", "             fontsize=14,\n", "             arrowprops=dict(facecolor='black', shrink=0.1)\n", "            )\n", "plt.legend()\n", "save_fig(\"aic_bic_vs_k_diagram\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's search for the best combination of values for both the number of clusters and the `covariance_type` hyperparameter:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# np.infty was removed in NumPy 2.0; np.inf works in every version\n", "min_bic = np.inf\n", "\n", "for k in range(1, 11):\n", "    for covariance_type in (\"full\", \"tied\", \"spherical\", \"diag\"):\n", "        bic = GaussianMixture(n_components=k, n_init=10,\n", "                              covariance_type=covariance_type,\n", "                              random_state=42).fit(X).bic(X)\n", "        # keep the combination with the lowest BIC seen so far\n", "        if bic < min_bic:\n", "            min_bic = bic\n", "            best_k = k\n", "            best_covariance_type = covariance_type" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "best_k" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "best_covariance_type" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## Using Clustering for classification : MNIST" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load the MNIST dataset and split it into a training set and a test set (take the first 60,000 instances for training, and the remaining 10,000 for testing)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import fetch_mldata\n", "mnist = fetch_mldata('MNIST original')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train = mnist['data'][:60000]\n", "y_train = mnist['target'][:60000]\n", "\n", "X_test = mnist['data'][60000:]\n", "y_test = mnist['target'][60000:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train a Random Forest classifier on the dataset and time how long it takes, then evaluate the resulting model on the test set." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "\n", "rnd_clf = RandomForestClassifier(random_state=42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "t0 = time.time()\n", "rnd_clf.fit(X_train, y_train)\n", "t1 = time.time()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"Training took {:.2f}s\".format(t1 - t0))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score\n", "\n", "y_pred = rnd_clf.predict(X_test)\n", "accuracy_score(y_test, y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use PCA to reduce the dataset's dimensionality, with an explained variance ratio of 95%." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.decomposition import PCA\n", "\n", "pca = PCA(n_components=0.95)\n", "X_train_reduced = pca.fit_transform(X_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train a new Random Forest classifier on the reduced dataset and see how long it takes. Was training much faster?" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rnd_clf2 = RandomForestClassifier(random_state=42)\n", "t0 = time.time()\n", "rnd_clf2.fit(X_train_reduced, y_train)\n", "t1 = time.time()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"Training took {:.2f}s\".format(t1 - t0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Oh no! Training is actually more than twice as slow now! \n", "\n", "Dimensionality reduction does not always lead to faster training time: it depends on the dataset, the model and the training algorithm. \n", "\n", "If you try a softmax classifier instead of a random forest classifier, you will find that training time is reduced by a factor of 3 when using PCA.\n", "\n", "Actually, we will do this in a second, but first let's check the accuracy of the new random forest classifier." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Evaluate the classifier on the test set: how does it compare to the previous classifier?" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_test_reduced = pca.transform(X_test)\n", "\n", "y_pred = rnd_clf2.predict(X_test_reduced)\n", "accuracy_score(y_test, y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is common for performance to drop slightly when reducing dimensionality, because we do lose some useful signal in the process. However, the performance drop is rather severe in this case. 
So PCA really did not help: it slowed down training and reduced performance. :(\n", "\n", "Let's see if it helps when using softmax regression:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression\n", "\n", "log_clf = LogisticRegression(multi_class=\"multinomial\", solver=\"lbfgs\", random_state=42)\n", "t0 = time.time()\n", "log_clf.fit(X_train, y_train)\n", "t1 = time.time()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"Training took {:.2f}s\".format(t1 - t0))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_pred = log_clf.predict(X_test)\n", "accuracy_score(y_test, y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Okay, so softmax regression takes much longer to train on this dataset than the random forest classifier, plus it performs worse on the test set. But that's not what we are interested in right now, we want to see how much PCA can help softmax regression. Let's train the softmax regression model using the reduced dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "log_clf2 = LogisticRegression(multi_class=\"multinomial\", solver=\"lbfgs\", random_state=42)\n", "t0 = time.time()\n", "log_clf2.fit(X_train_reduced, y_train)\n", "t1 = time.time()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"Training took {:.2f}s\".format(t1 - t0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nice! Reducing dimensionality led to a 4× speedup. 
:) Let's check the model's accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_pred = log_clf2.predict(X_test_reduced)\n", "accuracy_score(y_test, y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A very slight drop in performance, which might be a reasonable price to pay for a 4× speedup, depending on the application." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So there you have it: PCA can give you a formidable speedup... but not always!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }