LSTM
LSTM is a special RNN structure that is very popular at present; it effectively alleviates the RNN's exploding/vanishing gradient problems and its difficulty remembering long sequences.
Advantages
LSTM selectively remembers and forgets features by introducing a forget gate, an input gate and an output gate, so it processes and memorizes sequence data better than a plain RNN.
Schematic diagram:
Summary formulas:
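Written out explicitly (a standard formulation that matches the implementation below; U, V and b denote each gate's input weights, recurrent weights and bias), the per-step update is:

$$
\begin{aligned}
i_t &= \sigma(x_t U_i + h_{t-1} V_i + b_i) \\
f_t &= \sigma(x_t U_f + h_{t-1} V_f + b_f) \\
g_t &= \tanh(x_t U_c + h_{t-1} V_c + b_c) \\
o_t &= \sigma(x_t U_o + h_{t-1} V_o + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$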
Simply put, LSTM has three gates: an input gate, a forget gate and an output gate.
i_t, f_t and o_t are the gating values of the three gates respectively, each between 0 and 1 and controlling how much information passes through, while g_t is an ordinary RNN-style transformation of the input.
You can see from the formulas that LSTM has two outputs: the cell state c_t and the hidden state h_t.
c_t combines the previous cell state, scaled by the forget gate, with the new candidate g_t, scaled by the input gate; it is the content the cell itself remembers. Passing it through the output gate yields h_t, i.e. what the cell chooses to expose to the next unit.
So in practice we usually do not care about the cell state itself; we take the hidden state h_t as the final output.
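For reference, PyTorch's built-in nn.LSTM follows the same convention: its first return value is the sequence of hidden states, and the cell state only appears in the final (h_n, c_n) pair. A minimal sketch (the sizes are arbitrary, chosen only for illustration):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
x = torch.randn(4, 7, 10)        # (batch, seq_len, input_size)
output, (h_n, c_n) = lstm(x)
print(output.shape)              # torch.Size([4, 7, 20]): h_t at every time step
print(h_n.shape, c_n.shape)      # torch.Size([1, 4, 20]) each: final hidden and cell state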
Implementation
Implement LSTM manually with PyTorch.
Build the module according to the formulas above:
import math

import torch
import torch.nn as nn

class myLstm(nn.Module):
    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz
        # i_t: input gate parameters
        self.U_i = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.V_i = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_i = nn.Parameter(torch.Tensor(hidden_sz))
        # f_t: forget gate parameters
        self.U_f = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.V_f = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_f = nn.Parameter(torch.Tensor(hidden_sz))
        # c_t: candidate state (g_t) parameters
        self.U_c = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.V_c = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_c = nn.Parameter(torch.Tensor(hidden_sz))
        # o_t: output gate parameters
        self.U_o = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.V_o = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_o = nn.Parameter(torch.Tensor(hidden_sz))
        self.init_weights()

    def init_weights(self):
        # Uniform initialization in [-1/sqrt(hidden_sz), 1/sqrt(hidden_sz)],
        # the same scheme PyTorch uses for its built-in recurrent layers
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, init_states=None):
        # x: (batch, seq_len, input_size)
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            h_t, c_t = (
                torch.zeros(bs, self.hidden_size).to(x.device),
                torch.zeros(bs, self.hidden_size).to(x.device),
            )
        else:
            h_t, c_t = init_states
        for t in range(seq_sz):
            x_t = x[:, t, :]
            i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)  # input gate
            f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)  # forget gate
            g_t = torch.tanh(x_t @ self.U_c + h_t @ self.V_c + self.b_c)     # candidate state
            o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)  # output gate
            c_t = f_t * c_t + i_t * g_t        # update cell state
            h_t = o_t * torch.tanh(c_t)        # update hidden state
            hidden_seq.append(h_t.unsqueeze(0))
        # (seq_len, batch, hidden) -> (batch, seq_len, hidden)
        hidden_seq = torch.cat(hidden_seq, dim=0)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, (h_t, c_t)
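A quick sanity check of the module above (the sizes are hypothetical, chosen only for illustration):

batch_size, seq_len, input_sz, hidden_sz = 4, 7, 10, 20
model = myLstm(input_sz, hidden_sz)
x = torch.randn(batch_size, seq_len, input_sz)
hidden_seq, (h_t, c_t) = model(x)
print(hidden_seq.shape)       # torch.Size([4, 7, 20]): h_t for every time step
print(h_t.shape, c_t.shape)   # torch.Size([4, 20]) each: final hidden and cell state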