博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch
阅读量:4107 次
发布时间:2019-05-25

本文共 1767 字,大约阅读时间需要 5 分钟。

学习莫烦的pytorch视频,对一些函数进行搜索,记录功能以及测试

import torchimport numpy as npfrom torch.autograd import Variableimport torch.nn.functional as Fimport matplotlib.pyplot as pltimport pdbx=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)#增加dim个维数,类似 int a=0==>int a[1][1]={0}y=x.pow(2)+0.2*torch.rand(x.size())x, y =Variable(x), Variable(y)#tensor2Variable#plt.scatter(x.data.numpy(), y.data.numpy())#散点图#plt.show()class Net(torch.nn.Module):#继承torch的这个模块    def __init__(self, n_features, n_hidden, n_output):#层的信息        super(Net,self).__init__()        self.hidden = torch.nn.Linear(n_features, n_hidden)        #self.hidden2 = torch.nn.Linear(n_hidden,n_hidden)        self.predict = torch.nn.Linear(n_hidden, n_output)    def forward(self,x):        x=F.relu(self.hidden(x))#效果最好        #x=torch.sigmoid(self.hidden(x))        #x=torch.tanh(self.hidden(x))        #x=F.relu(self.hidden2(x))        x=self.predict(x)        return xnet = Net(1, 10, 1)#print(net)#打印网络结构plt.ion()#打开交互模式#optimizer = torch.optim.SGD(net.parameters(), lr=0.5)#优化方法loss_func = torch.nn.MSELoss()#均方差损失函数'''t=0for i in net.parameters():    print(t,'\n')    t=t+1    print(i)'''for t in range(200):    l=0.5    if(t>100&&l>=0.5):        l=l/10    optimizer = torch.optim.SGD(net.parameters(), lr=l)#优化方法    prediction = net(x)    #pdb.set_trace()    loss = loss_func(prediction, y)#预测值,真实值    optimizer.zero_grad()#梯度置0,梯度随着loss的变化而变化,每个epoch都不一样    loss.backward()#反向传播,计算梯度    optimizer.step()#根据梯度,学习率更新权重参数    if t%5==0:        plt.cla()#清除当前axes对象        plt.scatter(x.data.numpy(), y.data.numpy())        plt.plot(x.data.numpy(), prediction.data.numpy())        plt.text(0.5, 0, 'Loss=%.4f'%loss.item(), fontdict={
'size':20, 'color':'red'}) plt.pause(0.1)#暂停plt.ioff()#关闭交互模式plt.show()

转载地址:http://twksi.baihongyu.com/

你可能感兴趣的文章
Week 1 Functional Language
查看>>
LOJ2565 SDOI2018 旧试题 莫比乌斯反演、三元环计数
查看>>
深入理解Java内存模型——volatile
查看>>
使用 nm-applet 连接 WPA2-Enterprise wireless
查看>>
Linq学习笔记之一:Linq To XML
查看>>
TCP与UDP
查看>>
struts原理图
查看>>
2018.07.08 NOIP模拟 好数(线段树)
查看>>
nyoj 城市平乱(Dijkstra)
查看>>
安装Java EE失败,解决方案
查看>>
MFC 中消息循环实例
查看>>
solr默认查询设置
查看>>
多线程经典编程题 实践篇
查看>>
一个bug
查看>>
c#水晶报表教程
查看>>
NativeScript的开发体会
查看>>
Python模块Scrapy导入出错:ImportError: cannot import name xmlrpc_client
查看>>
基于已构建S2SH项目配置全注解方式简化配置文件
查看>>
DOM访问和处理HTML文档的标准方法
查看>>
class 的使用
查看>>