pytorch自定義不可導激活函數
今天自定義不可導函數的時候遇到了一個大坑。
首先我需要自定義一個函數:sign_f
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
    """Non-differentiable sign activation with a straight-through gradient.

    forward:  returns +1 where ``inputs >= 0`` and -1 everywhere else
              (including NaN entries, which fail the ``>= 0`` test).
    backward: passes the incoming gradient through unchanged where
              ``-1 <= inputs <= 1`` and zeroes it outside that range.

    Invoke it as ``sign_f.apply(x)``; calling ``sign_f(x)`` directly does not
    run it through autograd.
    """

    @staticmethod
    def forward(ctx, inputs):
        ctx.save_for_backward(inputs)
        # torch.where fills every element, unlike the legacy uninitialised
        # ``inputs.new(size)`` buffer, which left NaN positions as garbage.
        ones = torch.ones_like(inputs)
        return torch.where(inputs >= 0., ones, -ones)

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        # Work on a copy: editing grad_output in place can corrupt a gradient
        # tensor autograd still holds elsewhere.
        grad_input = grad_output.clone()
        grad_input[input_ > 1.] = 0
        grad_input[input_ < -1.] = 0
        return grad_input
然後我需要把它封裝為一個 nn.Module 類型,就像 nn.Conv2d 模塊封裝 F.conv2d 一樣,於是:
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
    # The module I need: wraps the custom sign_f Function as an nn.Module.
    # NOTE: this version is intentionally WRONG and is kept only to reproduce
    # the TypeError this article discusses. forward() calls the Function class
    # directly (``sign_f(inputs)``), which does not run it through autograd;
    # a custom autograd Function must be invoked as ``sign_f.apply(inputs)``.
    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)

    def forward(self, inputs):
        # Use the custom function -- incorrectly called directly (see NOTE).
        outs = sign_f(inputs)
        return outs
class sign_f(Function):
    """Non-differentiable sign activation with a straight-through gradient.

    forward:  returns +1 where ``inputs >= 0`` and -1 everywhere else
              (including NaN entries, which fail the ``>= 0`` test).
    backward: passes the incoming gradient through unchanged where
              ``-1 <= inputs <= 1`` and zeroes it outside that range.

    Invoke it as ``sign_f.apply(x)``; calling ``sign_f(x)`` directly does not
    run it through autograd.
    """

    @staticmethod
    def forward(ctx, inputs):
        ctx.save_for_backward(inputs)
        # torch.where fills every element, unlike the legacy uninitialised
        # ``inputs.new(size)`` buffer, which left NaN positions as garbage.
        ones = torch.ones_like(inputs)
        return torch.where(inputs >= 0., ones, -ones)

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        # Work on a copy: editing grad_output in place can corrupt a gradient
        # tensor autograd still holds elsewhere.
        grad_input = grad_output.clone()
        grad_input[input_ > 1.] = 0
        grad_input[input_ < -1.] = 0
        return grad_input
結果報錯
TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'
我試了半天,發現自定義的 Function 不能直接調用(`sign_f(inputs)`),而要通過它的 `apply` 方法調用(`sign_f.apply(inputs)`),詳見下面:
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
    """nn.Module wrapper around the custom ``sign_f`` autograd Function.

    Follows the same pattern as ``nn.Conv2d`` wrapping ``F.conv2d``: the
    module keeps the Function's ``apply`` entry point and calls it in
    ``forward``.
    """

    def __init__(self, *kargs, **kwargs):
        super().__init__(*kargs, **kwargs)
        # Custom autograd Functions must be invoked through ``.apply``;
        # calling the class itself does not run it through autograd.
        self.r = sign_f.apply

    def forward(self, inputs):
        return self.r(inputs)
class sign_f(Function):
    """Non-differentiable sign activation with a straight-through gradient.

    forward:  returns +1 where ``inputs >= 0`` and -1 everywhere else
              (including NaN entries, which fail the ``>= 0`` test).
    backward: passes the incoming gradient through unchanged where
              ``-1 <= inputs <= 1`` and zeroes it outside that range.

    Invoke it as ``sign_f.apply(x)``; calling ``sign_f(x)`` directly does not
    run it through autograd.
    """

    @staticmethod
    def forward(ctx, inputs):
        ctx.save_for_backward(inputs)
        # torch.where fills every element, unlike the legacy uninitialised
        # ``inputs.new(size)`` buffer, which left NaN positions as garbage.
        ones = torch.ones_like(inputs)
        return torch.where(inputs >= 0., ones, -ones)

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        # Work on a copy: editing grad_output in place can corrupt a gradient
        # tensor autograd still holds elsewhere.
        grad_input = grad_output.clone()
        grad_input[input_ > 1.] = 0
        grad_input[input_ < -1.] = 0
        return grad_input
問題解決了!
PyTorch自定義帶學習參數的激活函數(如sigmoid)
有的時候我們需要給激活函數設一個參數,但又不想把它固定成某個值,而是希望它和網絡一起自動學習。例如給 Sigmoid 加一個可學習的參數 alpha 來調節其陡峭程度。


函數為 $f(x) = \dfrac{1}{1 + e^{-\alpha x}}$,其中 $\alpha$ 是可學習參數(下面代碼中的 `weight`,初始值為 1)。代碼如下:
import torch.nn as nn
import torch
class LearnableSigmoid(nn.Module):
    """Sigmoid with a learnable input scale: ``sigmoid(weight * input)``.

    ``weight`` is a single trainable scalar initialised to 1.0, so a freshly
    constructed module produces the same values as ``nn.Sigmoid``.
    """

    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.empty(1, dtype=torch.float32), requires_grad=True)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to 1.0 (plain sigmoid)."""
        with torch.no_grad():
            self.weight.fill_(1.0)

    def forward(self, input):
        # torch.sigmoid is the numerically stable form of 1 / (1 + exp(-z)).
        # The hand-written formula overflows exp() to inf for very negative z,
        # and its backward pass then produces NaN gradients (0 * inf).
        return torch.sigmoid(self.weight * input)
驗證和Sigmoid的一致性
class LearnableSigmoid(nn.Module):
    """Sigmoid with a learnable input scale: ``sigmoid(weight * input)``.

    ``weight`` is a single trainable scalar initialised to 1.0, so a freshly
    constructed module produces the same values as ``nn.Sigmoid``.
    """

    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.empty(1, dtype=torch.float32), requires_grad=True)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to 1.0 (plain sigmoid)."""
        with torch.no_grad():
            self.weight.fill_(1.0)

    def forward(self, input):
        # torch.sigmoid is the numerically stable form of 1 / (1 + exp(-z)).
        # The hand-written formula overflows exp() to inf for very negative z,
        # and its backward pass then produces NaN gradients (0 * inf).
        # Note: with this form the printed grad_fn is SigmoidBackward0 rather
        # than the MulBackward0 produced by the hand-written formula.
        return torch.sigmoid(self.weight * input)
# Compare the stock nn.Sigmoid with a freshly initialised LearnableSigmoid
# (weight = 1.0); both lines should print the same values.
# The sample tensor is named ``x`` rather than ``input`` to avoid shadowing
# the builtin ``input()`` at module level.
Sigmoid = nn.Sigmoid()
LearnSigmoid = LearnableSigmoid()
x = torch.tensor([[0.5289, 0.1338, 0.3513],
                  [0.4379, 0.1828, 0.4629],
                  [0.4302, 0.1358, 0.4180]])
print(Sigmoid(x))
print(LearnSigmoid(x))
輸出結果
tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]])
tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]], grad_fn=<MulBackward0>)
驗證權重是不是會更新
import torch.nn as nn
import torch
import torch.optim as optim
class LearnableSigmoid(nn.Module):
    """Sigmoid with a learnable input scale: ``sigmoid(weight * input)``.

    ``weight`` is a single trainable scalar initialised to 1.0, so a freshly
    constructed module produces the same values as ``nn.Sigmoid``.
    """

    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.empty(1, dtype=torch.float32), requires_grad=True)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to 1.0 (plain sigmoid)."""
        with torch.no_grad():
            self.weight.fill_(1.0)

    def forward(self, input):
        # torch.sigmoid is the numerically stable form of 1 / (1 + exp(-z)).
        # The hand-written formula overflows exp() to inf for very negative z,
        # and its backward pass then produces NaN gradients (0 * inf), which
        # would poison the optimiser step on ``weight``.
        return torch.sigmoid(self.weight * input)
class Net(nn.Module):
    """Minimal network whose only layer is a LearnableSigmoid.

    Used to check that the activation's ``weight`` receives gradients and is
    changed by the optimiser.
    """

    def __init__(self):
        super().__init__()
        self.LSigmoid = LearnableSigmoid()

    def forward(self, x):
        return self.LSigmoid(x)
# Train for two SGD steps and print the parameters before and after each
# step, to confirm the LearnableSigmoid weight is actually updated.
net = Net()
print(list(net.parameters()))
# Single source of truth for the step size (the original declared an unused
# learning_rate = 0.001 while actually training with lr=0.01).
learning_rate = 0.01
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
input_data = torch.randn(10, 2)
# Arbitrary integer targets in [0, 7]; only the weight update matters here.
target = torch.FloatTensor(10, 2).random_(8)
# reduce=True, size_average=True are deprecated; reduction='mean' is the
# equivalent setting.
criterion = torch.nn.MSELoss(reduction='mean')
for _ in range(2):
    optimizer.zero_grad()
    output = net(input_data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(list(net.parameters()))
輸出結果
[Parameter containing:
tensor([1.], requires_grad=True)]
[Parameter containing:
tensor([0.9979], requires_grad=True)]
[Parameter containing:
tensor([0.9958], requires_grad=True)]
會更新~
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
您可能感興趣的文章:- pytorch方法測試——激活函數(ReLU)詳解
- PyTorch中常用的激活函數的方法示例
- Pytorch 實現自定義參數層的例子