文章目录

线性回归-标准方程法示例(python原生实现)

1. 导包

import numpy as np
import matplotlib.pyplot as plt
import random

2. 原始数据生成与展示

# 原始数据(因为是随机的,所以生成的数据可能会不同)
x_data = np.linspace(1,10,100)
y_data = np.array([j * 1.5 + random.gauss(0,0.9) for j in x_data])
# 作图
plt.scatter(x_data, y_data)
plt.show()

线性回归-标准方程法示例(python原生实现)-LMLPHP

3. 数据预处理

# 数据预处理
x_data = x_data[:, np.newaxis]
y_data = y_data[:, np.newaxis]
# 添加偏置
X_data = np.concatenate((np.ones((100, 1)), x_data), axis = 1)
print(X_data)

4. 定义利用标准方程法求weights的方法

def calc_weights(X_data, y_data):
    """
    标准方程法求weights
    算法: weights = (X的转置矩阵 * X矩阵)的逆矩阵 * X的转置矩阵 * Y矩阵
    :param x_data: 特征数据
    :param y_data: 标签数据
    """

    x_mat = np.mat(X_data)
    y_mat = np.mat(y_data)
    xT_x = x_mat.T * x_mat
    if np.linalg.det(xT_x) == 0:
        print("x_mat为不可逆矩阵,不能使用标准方程法求解")
        return
    weights = xT_x.I * x_mat.T * y_mat
    return weights

5. 求出weights

weights = calc_weights(X_data, y_data)
print(weights)

6. 结果展示

x_test = np.array([2, 5, 8, 9])[:, np.newaxis]
predict_result = weights[0] + x_test * weights[1]
# 原始数据
plt.plot(x_data, y_data, "b.")
# 预测数据
plt.plot(x_test, predict_result, "r")
plt.show()

线性回归-标准方程法示例(python原生实现)-LMLPHP

02-20 18:55