PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析
2023-09-14 09:04:43 时间
PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析
目录
输出结果
核心代码
#PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析
from sklearn.datasets import make_regression
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
sns.set()
x_train, y_train, W_target = make_regression(n_samples=100, n_features=1, noise=10, coef = True)
df = pd.DataFrame(data = {'X':x_train.ravel(), 'Y':y_train.ravel()})
sns.lmplot(x='X', y='Y', data=df, fit_reg=True)
plt.show()
x_torch = torch.FloatTensor(x_train)
y_torch = torch.FloatTensor(y_train)
y_torch = y_torch.view(y_torch.size()[0], 1)
class LinearRegression(torch.nn.Module): #定义LR的类。torch.nn库构建模型
#PyTorch 的 nn 库中有大量有用的模块,其中一个就是线性模块。如名字所示,它对输入执行线性变换,即线性回归。
def __init__(self, input_size, output_size):
super(LinearRegression, self).__init__()
self.linear = torch.nn.Linear(input_size, output_size)
def forward(self, x):
return self.linear(x)
model = LinearRegression(1, 1)
criterion = torch.nn.MSELoss() #训练线性回归,我们需要从 nn 库中添加合适的损失函数。对于线性回归,我们将使用 MSELoss()——均方差损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)#还需要使用优化函数(SGD),并运行与之前示例类似的反向传播。本质上,我们重复上文定义的 train() 函数中的步骤。
#不能直接使用该函数的原因是我们实现它的目的是分类而不是回归,以及我们使用交叉熵损失和最大元素的索引作为模型预测。而对于线性回归,我们使用线性层的输出作为预测。
for epoch in range(50):
data, target = Variable(x_torch), Variable(y_torch)
output = model(data)
optimizer.zero_grad()
loss = criterion(output, target)
loss.backward()
optimizer.step()
predicted = model(Variable(x_torch)).data.numpy()
#打印出原始数据和适合 PyTorch 的线性回归
plt.plot(x_train, y_train, 'o', label='Original data')
plt.plot(x_train, predicted, label='Fitted line')
plt.legend()
plt.title(u'Py:PyTorch实现简单合成数据集上的线性回归进行数据分析')
plt.show()
相关文章
- dbMigration .NET 数据同步迁移工具
- 8款数据迁移工具选型,主流且实用!
- 倍福TwinCAT(贝福Beckhoff)常见问题(FAQ)-可以用软件自带NC工具驱动但是程序无法让电机转动怎么办
- JAVA 校验身份证号码工具类(支持15位和18位)
- [工具]利用EasyRTSPClient工具检查摄像机RTSP流不能播放原因以及排查音视频数据无法播放问题
- .Net(c#)加密解密工具类:
- linux性能采用工具oprofile使用
- 单步调试学习NgRx createSelector 工具函数的使用方式
- Math/ML:时间序列数据集/时间序列预测任务的简介、常用算法及其工具、案例应用之详细攻略
- ML之FE:基于单个csv文件数据集(自动切分为两个dataframe表)利用featuretools工具实现自动特征生成/特征衍生
- CSDN:借助工具对【本博客访问来源】进行数据图表可视化(网友主要来自欧美和印度等)——记录数据来源截止日期20190811
- 这 5 款 Python 数据科学工具至少提效提升20%!
- 【阶段二】Python数据分析Pandas工具使用08篇:探索性数据分析:数据的描述:数据的分散趋势与数据的分布形态
- 【阶段二】Python数据分析Pandas工具使用06篇:探索性数据分析:异常数据的检测与处理
- 【阶段二】Python数据分析Pandas工具使用05篇:数据预处理:数据的规范化
- 【阶段二】Python数据分析Pandas工具使用03篇:数据预处理:多表合并与连接
- Hawk 数据抓取工具 使用说明(二)
- etlpy: 并行爬虫和数据清洗工具(开源)
- 003-and design-dva.js 知识导图-02-Reducer,Effect,Subscription,Router,dva配置,工具
- Python开发IDE工具PyCharm安装
- python工具方法 34 语义分割数据中类别频率统计及class weight计算
- python工具方法 25 txt标注(yolo格式标注)的目标检测文件转voc数据
- python工具方法 16 保存模型分类后的数据及分类错误的数据
- 【视频】React ReduxToolkit状态管理:创建store对象及redux调试工具的安装方法