深度学习基础:线性回归


样本 做如下约定:

$$ \tag{1} \mathbf{X} = \begin{bmatrix} \mathbf{x}_1^T \\ \mathbf{x}_2^T \\ \vdots \\ \mathbf{x}_m^T \end{bmatrix} \in \mathbb{R}^{m \times n} $$ $$ \begin{matrix} \tag{2} \mathbf{x}_i = \begin{bmatrix} x_{i1} \\ x_{i2} \\ \vdots \\ x_{in} \end{bmatrix} \in \mathbb{R}^n, & i \in \{1, \cdots, m\} \end{matrix} $$

模型定义

若有权重 $\mathbf{w} = (w_1, w_2, \cdots, w_n)^T \in \mathbb{R}^n$,偏置 $b \in \mathbb{R}$,则线性模型可以表示为:

$$ \tag{3} \hat{\mathbf{y}} = \mathbf{Xw} + b \in \mathbb{R}^m $$

展开公式 (3),即:

$$ \begin{cases} \tag{4} \hat{\mathbf{y}} = (\hat{y}_1, \hat{y}_2, \cdots, \hat{y}_m)^T\\ \hat{y}_i = \mathbf{x}_i^T \mathbf{w} + b = \langle \mathbf{x}_i, \mathbf{w} \rangle + b,&i \in \{1, \cdots, m\} \end{cases} $$

损失函数

已知样本 $(\mathbf{x}_i, y_i) _{i=1}^{m}$,且 $\mathbf{x}_i$ 的预测值为 $\hat{y_i}$,则定义可单个样本是损失函数:

$$ \tag{5} l^{(i)}(\mathbf{w}, b) = \frac 1 2 (\hat{y}_i - y_i)^2, i \in \{1, \cdots, m\} $$

总损失函数定义为:

$$ \tag{6} L(\mathbf{w}, b) = {\frac 1 m} \sum_{i=1}^m l^{(i)}(\mathbf{w}, b) = {\frac 1 {2m}} \lVert \mathbf{Xw} + b - \mathbf{y} \rVert ^2 $$

在训练模型时,我们希望寻找一组参数 $(\mathbf{w}^*, b^*)$,这组参数能最小化在所有训练样本上的总损失。如下式:

$$ \tag{7} \mathbf{w}^{\ast}, b^{\ast} = \argmin_{\mathbf{w}^{\ast}, b^{\ast}} L(\mathbf{w}, b) $$

可以求得解析解:

将 $\mathbf{w}$ 与 $b$ 合并为 $\overline{\mathbf{w}}$,$\overline{\mathbf{X}} = (\mathbf{X}, \mathbf{1})$,则公式 (6),可以写作:

$$ \tag{8} L(\mathbf{w}, b) = {\frac 1 {2m}} \lVert \overline{\mathbf{X}} \overline{\mathbf{w}} - \mathbf{y} \rVert ^2 $$

这很容易求得解析解:

$$ \tag{9} \overline{\mathbf{w}}^{\ast} = (\overline{\mathbf{X}}^T \overline{\mathbf{X}})^{-1} \overline{\mathbf{X}}^T \mathbf{y} $$

对于实际问题,往往模型很复杂很难求得解析解,大都仅仅求得其近似解。

梯度下降

由计算梯度得:

$$ \tag{10} \nabla_{\overline{\mathbf{w}}} L = {\cfrac 1 m} \overline{\mathbf{X}}^T (\overline{\mathbf{X}} \overline{\mathbf{w}} - y) $$

所以,参数更新:

$$ \tag{11} \begin{cases} \mathbf{w} \leftarrow \mathbf{w} - {\cfrac \eta m} \mathbf{X}^T (\mathbf{Xw} + b - \mathbf{y}) \\ b \leftarrow b - {\frac \eta m} \mathbf{1}^T (\mathbf{Xw} + b - \mathbf{y}) \end{cases} $$

其中 $\eta$ 表示学习率。


文章作者: xinetzone
版权声明: 本博客所有文章除特别声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 xinetzone !
评论
  目录