@borgwang
Last active April 12, 2024 09:28
quantile_regression
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Quantile Regression Tutorial"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 特点\n",
"\n",
"- 一般回归模型对自变量与因变量均值进行回归,假定了预测的误差的方差是固定的(实际中有时并不是)\n",
"- 分位数回归对自变量与因变量的不同条件分位数进行回归\n",
"- 可以得到一个置信区间,通过不同的分位了解预测的方差\n",
"\n",
"#### 基本原理\n",
"\n",
"假定因变量真实值为 $ y=(y_1, y_2, ..., y_n) $,目标值为 $ \\hat{y} = (\\hat{y_1}, \\hat{y_2}, ..., \\hat{y_n}) $\n",
"\n",
"分位数回归的损失函数是:\n",
"\n",
"$$ L_r(y, \\hat{y}) = (1-r)\\frac{1}{N}\\sum^{i}_{\\hat{y_i}\\geq y_i}(\\hat{y_i}-y_i) + r\\frac{1}{N}\\sum^{i}_{\\hat{y_i}< y_i}(y_i-\\hat{y_i}) $$\n",
"\n",
"其中 r 是分位数系数,这个损失函数是平均绝对误差的拓展,当 r=0.5 时退化成 LAD regression(Least absolute deviations)。\n",
"\n",
"Intuition:当 r > 0.5 时,损失函数对预测值偏小的数据惩罚更大(使 fitting 到曲线上移,值更大);当 r > 0.5 时,损失函数对预测值偏小的数据惩罚更大。即 r 控制了模型对于该分位系数下两类数据的不同惩罚程度。\n",
"\n",
"#### 梯度\n",
"\n",
"使用最简单的线性模型 y = wx + b ,上述损失函数对应 w,b 的梯度为\n",
"\n",
"$$ \\frac{dL_{w,b}}{dw} = (1-r)\\frac{1}{N}\\sum^{i}_{wx_i+b\\geq y_i}x_i + r\\frac{1}{N}\\sum^{i}_{wx_i+b< y_i}(-x_i) $$\n",
"\n",
"$$ \\frac{dL_{w,b}}{db} = (1-r)\\frac{1}{N}\\sum^{i}_{wx_i+b\\geq y_i}\\times1 + r\\frac{1}{N}\\sum^{i}_{wx_i+b< y_i}\\times(-1) $$\n"
]
},
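{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gradient formulas above can be sanity-checked numerically. The next cell is a minimal sketch rather than part of the original notebook: `pinball_loss` and `pinball_grad` are hypothetical helper names, the toy data and the evaluation point `(w0, b0)` are arbitrary, and the comparison uses central finite differences. The two columns it prints should agree closely as long as no residual sits exactly on the kink of the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def pinball_loss(w, b, x, y, r):\n",
"    # quantile loss of the linear model w*x + b at level r\n",
"    diff = (w * x + b) - y  # >= 0 where the prediction is too high\n",
"    return np.mean(np.where(diff >= 0, (1 - r) * diff, -r * diff))\n",
"\n",
"def pinball_grad(w, b, x, y, r):\n",
"    # analytic (sub)gradient, following the two formulas above\n",
"    diff = (w * x + b) - y\n",
"    coef = np.where(diff >= 0, 1 - r, -r)\n",
"    return np.mean(coef * x), np.mean(coef)\n",
"\n",
"rng = np.random.default_rng(0)  # hypothetical toy data, separate from X and Y\n",
"x_toy = rng.uniform(0, 10, 200)\n",
"y_toy = 2 * x_toy + 10 + rng.normal(0, 1, 200)\n",
"w0, b0, r0, eps = 1.5, 8.0, 0.9, 1e-6  # arbitrary evaluation point and step\n",
"\n",
"gw, gb = pinball_grad(w0, b0, x_toy, y_toy, r0)\n",
"num_gw = (pinball_loss(w0 + eps, b0, x_toy, y_toy, r0)\n",
"          - pinball_loss(w0 - eps, b0, x_toy, y_toy, r0)) / (2 * eps)\n",
"num_gb = (pinball_loss(w0, b0 + eps, x_toy, y_toy, r0)\n",
"          - pinball_loss(w0, b0 - eps, x_toy, y_toy, r0)) / (2 * eps)\n",
"print('dL/dw analytic %.6f numeric %.6f' % (gw, num_gw))\n",
"print('dL/db analytic %.6f numeric %.6f' % (gb, num_gb))"
]
},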
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x7fe26850d470>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# generate data with varying variances\n",
"X = np.random.uniform(0, 10, (100, 1))\n",
"Y = 2 * X + 10 + np.random.normal(0, 0.8 * X)\n",
"\n",
"plt.scatter(X, Y)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fe2906a8208>]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X+QXGWd7/H3N8Mow49lgEQNE9jJ3XUju6AJO1C4qVU3WDe4RBlxRdT1osWKlr/AtWISCoWt0kqAVcFSuTcrLFCyGAriQAUx6xIQN5ask0wgshHxSsAMuTAKIwJjnEy+94/unkxmzuk+3XN+9+dVlcrM6dN9nk5mvufp7/N9nsfcHRERKb45WTdARETioYAuIlISCugiIiWhgC4iUhIK6CIiJaGALiJSEgroIiIloYAuIlISCugiIiVxWJoXmzt3rvf29qZ5SRGRwtu2bduv3X1eo/NSDei9vb0MDg6meUkRkcIzsyejnKeUi4hISSigi4iUhAK6iEhJKKCLiJSEArqISEmkWuUiIpKkgaFhrtn8GE+PjnFCdxcrly+if0lP1s1KjQK6iJTCwNAwazbuZGx8AoDh0THWbNwJ0DZBXSkXESmFazY/NhnMa8bGJ7hm82MZtSh9CugiUgpPj441dbyMFNBFpBRO6O5q6ngZKaCLSCmsXL6Irs6OQ451dXawcvmijFqUPg2Kikgp1AY+VeUiIlIC/Ut62iqAT6eUi4hISSigi4iUROSAbmYdZjZkZpuq3y80s4fM7HEz22Bmr0iumSIi0kgzPfRLgF1Tvr8K+Iq7vxZ4HrgozoaJiEhzIgV0M1sAnAN8s/q9AcuAO6qn3Az0J9FAERGJJmoP/Vrgs8CB6vfHA6Puvr/6/R6gfYeWRURyoGFAN7MVwLPuvm3q4YBTPeT5F5vZoJkNjoyMtNhMERFpJEoPfSnwDjPbDXybSqrlWqDbzGp17AuAp4Oe7O7r3b3P3fvmzWu4abWIiLSoYUB39zXuvsDde4ELgC3u/n7gfuDvqqddCNyVWCtFRKSh2dShrwL+0cx+QSWnfkM8TRIRkVY0NfXf3R8AHqh+/UvgjPibJCIirdBMURGRklBAFxEpCQV0EZGSUEAXESkJrYcuIpKQgaHhVDfcUEAXEUnAwNAwazbuZGx8AoDh0THWbNwJkFhQV8pFRCQB12x+bDKY14yNT3DN5scSu6YCuohIAp4eHWvqeBwU0EVEEnBCd1dTx+OggC4ikoCVyxfR1dlxyLGuzg5WLl+U2DU1KCoikoDawKeqXERESqB/SU+iAXw6pVxEREpCAV1EpCSUchGRtpD2rM0sKKCLSOllMWszC0q5iEjpZTFrMwvqoYtI6aU5azPL1I566CJSemnN2qyldoZHx3AOpnYGhoZjvU4YBXQRKb20Zm1mndpRykVESi+tWZtZLMg1VcOAbmaHAw8Cr6yef4e7X2FmNwFvBn5bPfWD7r4jqYaKiMxGGrM2T+juYjggeCe5INdUUXro+4Bl7v6imXUC/2lm91YfW+nudyTXPBGRiiLUka9cvuiQ8khIfkGuqRoGdHd34MXqt53VP55ko0REpipKHXkWC3JNZZV43eAksw5gG/CnwNfdfVU15fJGKj34+4DV7r6v3uv09fX54ODgrBstIu1l6botgamMnu4utq5elkGL0mVm29y9r9F5kapc3H3C3RcDC4AzzOwUYA3wOuB04DhgVUhDLjazQTMbHBkZifwGRERqsh5sLIqmyhbdfRR4ADjb3fd6xT7gX4EzQp6z3t373L1v3rx5s26wiLSfLHb/KaKGAd3M5plZd/XrLuCtwM/MbH71mAH9wE+TbKiItK8sdv8poihVLvOBm6t59DnA7e6+ycy2mNk8wIAdwEcTbKeItLGsBxuLItKgaFw0KCoi0rxYB0VFRCT/FNBFREpCAV1EpCQU0EVESkIBXUSkJLR8rohIzLJaSE
wBXUQkRlkuJKaUi4hIjLLctUg9dBHJXBHWOo8qy4XE1EMXkUxlvbFy3LJcSEwBXUQylfXGynHLciExpVxEJFNlW+s8y4XEFNBFJFNZb6ychDQ2pA6ilIuIZEprncdHPXQRyZTWOo+PArqIZC6rFEXZKKCLSObKVIeeJQV0EclUllPly0aDoiKSqbLVoWdJAV1EMlW2OvQsKaCLSKaynCpfNg0Dupkdbmb/ZWYPm9mjZvZP1eMLzewhM3vczDaY2SuSb66IlI3q0OMTpYe+D1jm7m8AFgNnm9mZwFXAV9z9tcDzwEXJNVNEyqp/SQ9rzzuVnu4uDOjp7mLteadqQLQFDatc3N2BF6vfdlb/OLAMeF/1+M3AlcD18TdRRMpOdejxiJRDN7MOM9sBPAt8H/i/wKi776+esgcI/N8ws4vNbNDMBkdGRuJos4iIBIgU0N19wt0XAwuAM4CTg04Lee56d+9z97558+a13lIRKaWBoWGWrtvCwtX3sHTdlsKugz7DxAR8/vNgBmeemcolm5pY5O6jZvYAcCbQbWaHVXvpC4CnE2ifiMQsT7MySzmp6BvfgI9//NBjc9IpKIxS5TLPzLqrX3cBbwV2AfcDf1c97ULgrqQaKSLxyNvuQKWZVPTcc5WeuNmhwfz1r6889qMfpdKMKLeN+cD9ZvYI8BPg++6+CVgF/KOZ/QI4HrghuWaKSBzyFkALP6noYx+rBPHjjz/0+Kc/De7w8MNw7LGpNSdKlcsjwJKA47+kkk8XkYLIWwAt5OYWP/85LAqpkR8dhWOOSbc9U2imqEgbyduszMJMKnKH00+v9ManB/NvfrPyuHumwRy02qJIYvI0+FizcvmiQwYhIb4A2sr7zf3mFjfcAP/wDzOPd3bCyy/DYfkKoflqjUhJJFW9MdubRFIBdDbvN3eTil5+GY48MvixBx6AN7851eY0wyoTQdPR19fng4ODqV1PJCtL120JzA33dHexdfWyll5zetCESu/6XX/Zw/0/G8m0h5vE+01dTw88HVJ9nWKcDGJm29y9r9F5yqGLJCCJwcewCpVbf/xU5mWIeRtsjezxxw+WG04P5t/5zsHceEEooIskIInBx7DgOD3cZFGGmLfB1oZqQfzP/mzmY7Ug3t+ffrtmSQFdJAFJVG80ExzT7hkXolrl6qsPBvLpHn98MpAXeSkCDYqKJCCJwcegChUjeBGltHvGua5WCQrgNdPSKUVfikCDoiIFMr3K5W9eN487tw3PGCht+/XEu7vht78NfuyFF+DoowMfyuvgbtRBUfXQpdTyWAs+G0Elfn1/fFyp3mPLRkfDp9mfeio88kjDlyjs4G6VArqUVtE/PkeVuzruiGK72TaRUmmkkEsRTKFBUSmtvC1EJQc1s+pj4CDlli3hA5wXXdRyuWEhBnfrUA9dSqvoH5/LrN7NdmovffqnrK1rzoI1IS8aw3hgrgd3I1BAl9Iq+sfnMot6s71m82Os3biO/v/+QfALfe97sHx5rG0ragoLFNClQJrNuSa5EFUWyjTAe0xXJ6Nj4zOOH3KzNWNryPMXrtrEE+vOSaZxBaaALoXQygBn0T8+T1WmAd6BoWFe+sP+Gcc751jdlMppn7yV546oLE/bo09ZgRTQpRCi5lynK/LH56laff95dM3mxxifOJjvnvvS8wx+7QOh5598+b2l+ZSVNAV0KYR2H+As0/uvtXn3VSvCTzpwYLKCZW2JUk1JU0CXQmj3Ac7SvP9rr+WJqz4d/nhApUqWn7KKNm6hOnQphKLXB89W0d7/9NrxyZrxT88M5r2rNnHy5fcysH1PBi0N10ytfF407KGb2YnALcBrgAPAene/zsyuBD4MjFRPvczdv5tUQ6W9lWmAsxVFev+1QLjrC28LPefnH/wYH1r0Lp4eHaMnp++liOMWDRfnMrP5wHx3325mRwPbgH7gfOBFd//nqBfT4lwixfsY37QYp+JnaeHqewJXsjRIvWQytsW53H0vsLf69e
/MbBdQop8+kfTkufxwVjeaOkH8bz78f3jiuJ5KIIynqako4rhFUzl0M+sFlgAPVQ99wsweMbMbzSxkmTMRqcnr+jIt5YuHh8PXU6GSG+9dtYknjqvcFPIcCIMUbdwCmgjoZnYUcCdwqbu/AFwP/AmwmEoP/kshz7vYzAbNbHBkZCToFJG2kdfyw7AbzWduf3hmUK8F8QULZr7QgQMMbN/DyZffe8jhvAfCIP1Lelh73qn0dHdhVCYz5X2d+UgbXJhZJ7AJ2OzuXw54vBfY5O6n1Hsd5dCl3eV1A4WwfDFUgvHdT93Fa2/53+EvELDzT6nHCVIWWw7dzAy4Adg1NZib2fxqfh3gncBPW22sSLtotL5MVoEwLF9cd/JPnc5gWWboFk2UiUVLgQ8AO81sR/XYZcB7zWwxlS0NdwMfSaSFIiVSr/wwywHTqTeaukH8r/8aHnww0bYUSd4+iWhPUZGcyDwdU6dSZena+zJNCeXR9BswJLefq/YUFSmYTAZM6wTxd79vHT858ZRKkCrYgGYa8jjxSAFdJCdSq3veswdOPDH04YHteybTCHmdxZkHeaxYUkAXyYnEN+SoN4NzYgLmVKqY+8l+klMR5HHikRbnEsmJROqeP/7xupN/JjdTntM4FARu1tzG8jjxSD10kRyJrdwv5vVU8rxkQVbyuGCaArpIxmIrfasXxLu74fnnW25jHgcA8yBv9fYK6CIZiqXnm8LqhnkcAExK3mrLm6EcukiGWl6sq5YXDwrm99xzMDcek7CBvqItuNVIETe1mEoBXUqjiIN2TfV8G6xuOBnE//ZvY2xhRR4HAJOQ19Uwo1LKRUoh60G7Vj+mRyp9q5dS2b8fOjrCH49JHgcAk1D01JICupRCloN2s7mZhNWef2vbTWBnhT8xg51/8jYAmIQ81pY3QwFdSiHLntWVdz/a8s1kes/3iRZXN5R4JD65K2EK6FIKWfWsBoaGGR0bD3ws6s2k/7QF9Nc7IUeBvMgVIFEUPbWkgC6lkFXPqt5gWcObScE2U856nCItRU4tqcpFSiGr7cLq9cIDbyb1yg3/7d9iLzeMU9ErQNqBeuhSCFE+6mfRswpL9Rx7ROfBtjzzDLzmNeEvktMAPl3RK0DagXroknt5mOwRVuMeVp99xdv/4mBPPCiY79uX6954kHaZXFRk6qFL7mW9jkiU3HHt08MXf3QL7/vh7fCFkBebZQDPclCy6BUg7UABXTLXKEjF+VG/lYDY6IbSv6SH/tMWhL9ATL3wrAcli14B0g4U0CVTUYJUXCWJrQbEsBvH1jVnwZo6F4w5nZL1JxUodgVIO1AOXTIVpXIirnVEWq3SmH7j2H3VCnaHTQCq5cUTyI3HPShZxLVvpL6GPXQzOxG4BXgNcABY7+7XmdlxwAagF9gNnO/urS+4LG0pSpCK66N+2LWGR8dYum5L6GuuXL6ofkpl/Xr48Iebaksr4pw8lXX6RpIRJeWyH/iMu283s6OBbWb2feCDwH3uvs7MVgOrgVXJNVXKKGqQiuOjfti1ICSg/eY3MHdu+CzOlCtU4hyUzEP6RuLXMOXi7nvdfXv1698Bu4Ae4Fzg5uppN0P92csiQeJclrVRCiHoWlNNpl9q5YZz5wacNJZaueH09wPENnlKNeXl1NSgqJn1AkuAh4BXu/teqAR9M3tV7K2T0ms2nRJWpRIlhTD1WtN76p/9wU187Md3hDc05d542PtZe96pbF29bNavX/RVBSWYecQfVDM7CvgB8EV332hmo+7ePeXx59392IDnXQxcDHDSSSf95ZNPPhlPy6XtTA9yUOnNrz3v1MAgDZVebFAAXLpuC8OjY+GDm5DppJ9a+6YLez/NqvdvqZRL/pjZNnfva3RepB66mXUCdwK3uvvG6uFnzGx+tXc+H3g26Lnuvh5YD9DX11ecaXGSO/XyvvUGPAeGhg8NUmZsrXehHMzeTDoloprycopS5WLADcAud//ylIfuBi4E1lX/viuRFk
optTLBp16QqzfgOZl6qVOpsnTtfbkJaANDw8wxYyLgxhJnSkQ15eUTpYe+FPgAsNPMdlSPXUYlkN9uZhcBTwHvTqaJUjatlszVy/sGVYAAB1MqQVPxr70WLrkEoH6PPUW1f5ugYN5O0+zjWOKg7Gu3B2kY0N39P4GwhZvr7JElEqzVkrl6ZXu15126YQd/9PsXeeS6C8IbkIOUSpigfxuADrO2yW/HUSPfrnX2mikqqWs1P9xozfP+0xaw+6oVgcH8rH/aVIjVDcP+DQ64lzoQTRXHuuvtuna71nKR1LVSMjf94/NX3rO4EuA+9zk4LWxpQ+hdtalSvfH2U2Npe9KKWk4YZ3ojjgHhdq2zVw9dUtfsZKKg9dD7T1tQmfzzhZnBfGD7HpauvY+FqzaltnNRXOKcaJWWuNerj2Pd9XZdu109dElMWK+t2ZK52sfnujXjHFqpktY66a30Sus9r4jlhHEvIxDHEgftunZ75IlFcejr6/PBwcHUrifZiXXiSp3NlE++/N5MJse0+v7KNqFnYGiYSzfsCHzMgCfWndPy66rK5aCoE4sU0CU2U3+BwuqoI890rBPEr37T/+IbbzyfjtleYxZancmZ9AzQNAXdnKYq4nvKq1hnioo0Mv2XOyjQQoNBqZdegqOOCn24d9Wmya+7OjtCA0kaA1+tDrqVabAurMQSoHOO8fIf9rNw9T2F7x0XiQZFJRb1frmnChyUqq1uGBTMX3gB3BnYvmdGuWJPhgNfrQ66lWmwru5NyOD5l8cz29S7XSmgSyzCpt1Pdcig1Pr1BwN5kFrN+NFHA5XBwq2rl/HEunPYunoZ/Ut6Mq0IafXaRaxiCRN2E+owY3zi0E9o7VADngdKuUgswvLZUBkcm/zYHeNmyllWhLR67SJWsYQJqyTJMhXW7jQoWnB5GcnvXX1P6GONyg3zPntTwgX9/DW7lLE0pkHRNpCn9Sp6AmY45nWtcYlPWM1/O9aA54Fy6AWWp/Uqarnh3VetmPwzw9VXF2I9FZmdRmvuSHLUQy+weiVwqaZi9u2j/7QFiW+mnJf0kjSmtdazoYBeYGELOR3T1ZlOKqbO5B9GR+GYY2K7VJ7SSyJ5pZRLgYWVwJmRXCpm48Zo5YYxBnPIV3pJJK8U0AssLFc5+vJ44PmzKhurBfF3vWvmY7UgnmBuvN6eoQtX38PSdVs0cUXanlIuBReUqwwrG2t6NuKCBTBcJ0imOLhZb8/QqbMRQSkYaV/qoZfQrGcj1nrjQcE8hd54kKD3NJ1SMNLu1EMvoZZmI9Yb4Lz+evjoRyNfP4lqlOnvKex2otmI0s4U0EsqUtnY+Di84hXhj7fQC0+qGmX6TeKlffsZHZs5VlDERa5E4tIwoJvZjcAK4Fl3P6V67Ergw8BI9bTL3P27STVSYlavN/7889DdHellwqZ9B1WjfOb2h4HWgnrQTaKzw+icY4wfOHjT0WxEaXdRcug3AWcHHP+Kuy+u/lEwz7vNm6OVGzYRzIP2kQwbuJxwb3kJ1aCbxPiEc9Thh2k2osgUDXvo7v6gmfUm3xSpp+W8dL3e+CwGNsN64vVWXWx1n8mwvPjoy+MMff5/NvVaImU2myqXT5jZI2Z2o5kdG3aSmV1sZoNmNjgyMhJ2mtTR9K7qb3pTtN74LIQF2Qn3utUoUdZNn65Mm0KIJKnVgH498CfAYmAv8KWwE919vbv3uXvfvHnzWrxce4s8S7IWxH/4w5kvEnO5YVgwraU+OkJuJgZNp13KtCmESJJaCuju/oy7T7j7AeBfgDPibVY+DQwNs3TdltRnJtbdh7IWxAMC6Jff/kkGtu9JpGa8XpDtX9LDl85/A0Eh3aHpWnGt3icSTUtli2Y23933Vr99J/DT+JqUT1kuDjV9luScAxP88ppzQ88/ZDPlhNrYqNa9f0kPl27YEfjcVmrFtXqfSGNRyhZvA94CzDWzPcAVwFvMbDGVDtdu4CMJtjEX6qU9kg40ta2+dn3hbeEn/e
Y3LF2/Y0aOOsk2NgqyQZtegHLfIkmJUuXy3oDDNyTQllyrm/ZI0uAg/aefHmmt8czaGCJsz0nlvkWSobVcIgrrVXYf0ZnMBWt58dNPn/HQZF58Wm48b9Ugyn2LpEtT/yNauXwRK+94mPGJQ4Poi7/fz8DQcDxB6iMfgfXrAx8aPfwoFl/ybSA8L57HHrFy3/miXZ/KTQE9ov4lPVy28ZEZAX38gM8+R11n8s/StfdFzou3tCiXtA3t+lR+CugRDQwN8/L4gcDHWspR15vB+a1vwfvfX3nt1fcEnjJc3Tc0KKjrl1OCZDmwL+lQQI+oXu105Bz1gQPQUWdN74B68XobO6h3Jc3I26C5xE+DohHV+6FvmKOuDXAGBfPnnqs7g7Pexg5ZbeiQ1QQrmZ28DZpL/BTQIwqtcunqDO4h79wZbT2VY0OXwQEOVoqESbt31fS6MpIbWkKh/BTQIwr7ZVjxhvmH9FYng/jrXz/zRVpcT6V/SQ89OeldRV5XRnJHZaTlpxx6RP1Lehh88jlue+hXTLjTYcZpJx3DnduGWfXdr/PB7ZuCn3j88fDrX8/6+nkpSVQettg0aF5ubRvQm63HHRga5s5tw5NrfU+4c+vFfxV6fu+qTXR1dlR6QC1cb7q8lCSGDdIqDyuSvbYM6GH1uINPPsf9PxsJDJi1VMMtGz7Hm3YPBb7uZ8/+FLe/4eCGC1NTEXHU/+ahd5WXTwoiMpN5Akurhunr6/PBwcHUrhdm6botgb1Mg0N2k5/sYS8+AeaEDzdMXd0w6DXDerU93V1sXb2siZbng2YbiqTLzLa5e1/D89oxoC9cfQ9R3vXuq1aEPnbqpRv43SuPbPgaPd1dPF2tCAl7XAFRROqJGtDbMuVSb7JO99gL7Pjq+wIfG+06msWfui3ydWqpiGs2PxZ6PU2/jk6fDETqK0VAb/YXPSgPXK83XiszfGBomI7bHw7dBBkOpm2m97ynX28qTb9uTOuQiDRW+IAe9It+6YYdXHn3o1z5jr8I/GWvHfuP6zfwtX/5TODrfvWN7+H6sz44WaUy9XnTg3NYEJ9+vXo9dZX91ad1SEQaK3xAD/pFBxgdGw/vwZnRD4GbRixde99kT39tg+DczEf/WoVK2ICsyv7qU/27SGOFD+j1fqEP6cGtWgVXXx184vbtsGQJAFsjXHM25YMq+2uN6t9FGitsQK/lzetWq7izdc1ZsCb88Vavm8QEIQ36hdONUKSxQgb0ywd2cuuPnwoN5juuu4Du378Y/OCLL8KRjcsNg8Q1MBfUw9egX315mSkrkmcNA7qZ3QisAJ5191Oqx44DNgC9wG7gfHd/PrlmVgwMDXPl3Y8yOjY+47Gj973EzmvfE/i8J151Ess+9I1KEPj5KP1LWgvoSQ7MadCvsTzMlBXJsyg99JuArwG3TDm2GrjP3deZ2erq96vib97BFMfw6NiMmZwA9974CU4e2R343Mu/8wh3bhuOrdeb5MCcBv1EZLYaBnR3f9DMeqcdPhd4S/Xrm4EHSCCgT09D1IL5MWO/4+GvvjfwOWuWf4KB089h7Xmncn/Mvd4kB+byOuinvL5IcbS6Hvqr3X0vQPXvV8XXpIOmpyGW/eK/2H3VisBg3rtqE72rNvHgW945ucZz3L3eJDcIyOPmA9rMQqRYEh8UNbOLgYsBTjrppKaeWwu8YT3y11/ybV44/CgM+PszT+IL/Yfu7BN3rzfJgbk8Dvopry9SLK0G9GfMbL677zWz+cCzYSe6+3pgPVQW52rmIrWAfPj+fQD89pVH8r4Lvsijr/nTyXOOPaKTK94ePCM0iVK3JAfm8jbop7y+SLG0GtDvBi4E1lX/viu2Fk1RC8jPHD13cona2g6dUVYpzGOvt0jymtcXkWBRyhZvozIAOtfM9gBXUAnkt5vZRcBTwLuTaFwcATlvvd4i0WQekWJpy/XQJTpVuYhkT+uhSyz0CUekOBTQRaQtlfHTpwK6iLSdsq6d1OrEIhGRwq
o3x6LIFNBFpO2UdY6FArqItJ2wuRRFn2OhgC4ibSePayfFQYOiItJ2yjqLXAFdRNpSGedYKOUiIlISCugiIiWhgC4iUhIK6CIiJaGALiJSEqkun2tmI8CTLT59LvDrGJtTFHrf7aMd3zPofUfxx+4+r9FJqQb02TCzwSjrAZeN3nf7aMf3DHrfcb6mUi4iIiWhgC4iUhJFCujrs25ARvS+20c7vmfQ+45NYXLoIiJSX5F66CIiUkfuA7qZnW1mj5nZL8xsddbtSYOZnWhm95vZLjN71MwuybpNaTKzDjMbMrNNWbclLWbWbWZ3mNnPqv/vb8y6TWkws09Xf8Z/ama3mdnhWbcpCWZ2o5k9a2Y/nXLsODP7vpk9Xv372NleJ9cB3cw6gK8DbwP+HHivmf15tq1KxX7gM+5+MnAm8PE2ed81lwC7sm5Eyq4DvufurwPeQBu8fzPrAT4F9Ln7KUAHcEG2rUrMTcDZ046tBu5z99cC91W/n5VcB3TgDOAX7v5Ld/8D8G3g3IzblDh33+vu26tf/47KL3e51vkMYWYLgHOAb2bdlrSY2R8BbwJuAHD3P7j7aLatSs1hQJeZHQYcATydcXsS4e4PAs9NO3wucHP165uB/tleJ+8BvQf41ZTv99Amga3GzHqBJcBD2bYkNdcCnwUOZN2QFP0PYAT412qq6ZtmdmTWjUqauw8D/ww8BewFfuvu/55tq1L1anffC5VOHPCq2b5g3gO6BRxrm7IcMzsKuBO41N1fyLo9STOzFcCz7r4t67ak7DDgNOB6d18CvEQMH7/zrpozPhdYCJwAHGlmf59tq4ot7wF9D3DilO8XUNKPZNOZWSeVYH6ru2/Muj0pWQq8w8x2U0mvLTOzb2XbpFTsAfa4e+1T2B1UAnzZvRV4wt1H3H0c2Aj8VcZtStMzZjYfoPr3s7N9wbwH9J8ArzWzhWb2CioDJndn3KbEmZlRyafucvcvZ92etLj7Gndf4O69VP6vt7h76Xts7v7/gF+ZWW2H4rOA/86wSWl5CjjTzI6o/syfRRsMBk9xN3Bh9esLgbtm+4K53lPU3feb2SeAzVRGwG9090czblYalgIfAHaa2Y7qscvc/bsZtkmS9Ung1mrH5ZfAhzJp2XBJAAAAZUlEQVRuT+Lc/SEzuwPYTqWya4iSzho1s9uAtwBzzWwPcAWwDrjdzC6icnN796yvo5miIiLlkPeUi4iIRKSALiJSEgroIiIloYAuIlISCugiIiWhgC4iUhIK6CIiJaGALiJSEv8fVa/kGbcv+t8AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# normal linear regression\n",
"from sklearn.linear_model import LinearRegression\n",
"lr = LinearRegression()\n",
"lr.fit(X, Y)\n",
"plt.scatter(X, Y)\n",
"plt.plot(X, lr.predict(X), color='red')"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# 一个无敌简易版的分位数回归\n",
"class QuantileRegression:\n",
" def __init__(self, r=0.5):\n",
" self.r = r\n",
" self.batch = 8\n",
" self.w, self.b = 0.0, 0.0\n",
" self.lr = 0.1\n",
" \n",
" def fit(self, X, y):\n",
" for n in range(5000):\n",
" random_idx = np.random.randint(0, len(X), self.batch)\n",
" train_X, train_y = X[random_idx], Y[random_idx]\n",
" pred_y = self.w * train_X + self.b\n",
" grad_w = np.mean((pred_y > train_y) * (1 - self.r) * train_X) + np.mean((pred_y < train_y) * self.r * - train_X)\n",
" grad_b = np.mean((pred_y > train_y) * (1 - self.r)) + np.mean((pred_y < train_y) * self.r * -1)\n",
" self.w -= self.lr * grad_w\n",
" self.b -= self.lr * grad_b \n",
" print('training done! w: %.4f b: %.4f' % (self.w, self.b))\n",
" \n",
" def predict(self, X):\n",
" return self.w * X + self.b\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training done! w: 0.9305 b: 9.5375\n",
"training done! w: 2.1781 b: 10.0625\n",
"training done! w: 2.8371 b: 10.1125\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fe23efff940>"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(X, Y)\n",
"for r in [0.1, 0.5, 0.9]:\n",
" qr = QuantileRegression(r)\n",
" qr.fit(X, Y)\n",
" plt.plot(X, qr.predict(X), label='quantile'+str(r), alpha=0.6)\n",
"plt.legend()"
]
},
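{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough check on the fits above: for data generated as Y = 2X + 10 + N(0, 0.8X), the true r-quantile line is $2x + 10 + 0.8\\,x\\,z_r$, where $z_r$ is the standard normal r-quantile, so the slopes at r = 0.1, 0.5 and 0.9 should be about 0.97, 2.0 and 3.03. The next cell is a sketch rather than part of the original notebook: `fraction_below` is a hypothetical helper, the sample is freshly simulated instead of reusing `X` and `Y`, and the constant 1.2816 is the rounded standard normal 0.9-quantile. It checks that the share of points on or below each theoretical line is close to r."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def fraction_below(y_true, y_pred):\n",
"    # share of observations at or below the predicted line\n",
"    return float(np.mean(y_true <= y_pred))\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x_sim = rng.uniform(0, 10, 20000)  # same generating process as above, larger sample\n",
"y_sim = 2 * x_sim + 10 + rng.normal(0, 0.8 * x_sim)\n",
"\n",
"Z90 = 1.2816  # standard normal 0.9-quantile (rounded); the 0.1-quantile is -Z90\n",
"coverage = {}\n",
"for r, z in [(0.1, -Z90), (0.5, 0.0), (0.9, Z90)]:\n",
"    slope = 2 + 0.8 * z  # slope of the theoretical r-quantile line\n",
"    coverage[r] = fraction_below(y_sim, slope * x_sim + 10)\n",
"    print('r=%.1f slope=%.4f fraction below=%.4f' % (r, slope, coverage[r]))"
]
},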
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 和 StatsModel 的包对比一下\n",
"import pandas as pd\n",
"import statsmodels.formula.api as smf\n",
"\n",
"data = pd.DataFrame(np.hstack([X, Y]), columns=['x', 'y'])\n",
"mod = smf.quantreg('y ~ x', data)\n",
"\n",
"plt.scatter(X, Y)\n",
"for r in [0.1, 0.5, 0.9]:\n",
" res = mod.fit(q=r)\n",
" y_ = res.params['x'] * X + res.params['Intercept']\n",
" plt.plot(X, y_, label='quantile'+str(r), alpha=0.6)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Ref\n",
"- [这份 tutorial 的前几页](https://support.sas.com/resources/papers/proceedings17/SAS0525-2017.pdf)\n",
"- [StatsModels Example](https://www.statsmodels.org/dev/examples/notebooks/generated/quantile_regression.html)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}