超全总结!Pythorch 构建Attention-lstm相关模型!!
时序数据分析在预测未来事件、检测异常、识别模式等领域中广泛应用。
因此,下面将详细介绍如何使用PyTorch框架构建一个基于Attention机制的LSTM(长短期记忆网络)模型来处理时序数据。
原理阐述
LSTM网络
咱们先聊基础,关于LSTM在此前夜讲过很多了。大家可以翻回去看看~
LSTM是一种特殊的RNN(循环神经网络),适用于处理和预测基于时间的数据。它通过三个门(输入门、遗忘门、输出门)来控制信息的流动,从而能够学习长期依赖关系。
LSTM的核心公式如下:
- 遗忘门:决定当前时刻 遗忘多少先前状态信息。
输入门:决定当前时刻 添加多少新信息。
候选记忆单元:生成新的候选信息。
当前记忆单元:综合遗忘信息和新输入信息。
输出门:决定输出多少信息。
最终输出:
其中, 是sigmoid函数, 是逐元素乘积。
Attention机制
Attention机制在处理长序列时特别有用,因为它可以帮助模型关注序列中的重要部分。它通过计算不同时间步的加权和来突出不同的输入数据对输出的重要性。
Attention机制的核心公式:
- 计算注意力权重:
其中,score函数可以是点积、双线性或其他相似性测量方法。
- 归一化权重:
加权和:
最终,Attention机制的输出是输入时间步的加权和,能使模型更有效地关注重要的信息。
模型训练
为了演示Attention-LSTM在时序预测中的应用,我们使用一个模拟的时序数据集进行预测。
Pytorch代码实现
下面是使用PyTorch构建Attention-LSTM模型的代码示例。我们使用一个简单的正弦波数据集来说明。
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# 创建模拟数据
def create_sin_wave(seq_len, n_samples):
x = np.linspace(0, 50, n_samples)
data = np.sin(x)
return data
# Attention机制实现
class Attention(nn.Module):
def __init__(self, hidden_dim):
super(Attention, self).__init__()
self.hidden_dim = hidden_dim
self.attn = nn.Linear(hidden_dim, hidden_dim)
self.context = nn.Linear(hidden_dim, 1, bias=False)
def forward(self, hidden_states):
attn_weights = torch.tanh(self.attn(hidden_states))
attn_weights = self.context(attn_weights).squeeze(2)
attn_weights = torch.softmax(attn_weights, dim=1)
context_vector = torch.sum(attn_weights.unsqueeze(2) * hidden_states, dim=1)
return context_vector, attn_weights
# LSTM模型实现
class AttentionLSTM(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, n_layers):
super(AttentionLSTM, self).__init__()
self.hidden_dim = hidden_dim
self.n_layers = n_layers
self.lstm = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
self.attention = Attention(hidden_dim)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
h_0 = torch.zeros(self.n_layers, x.size(0), self.hidden_dim).to(x.device)
c_0 = torch.zeros(self.n_layers, x.size(0), self.hidden_dim).to(x.device)
out, _ = self.lstm(x, (h_0, c_0))
context_vector, attn_weights = self.attention(out)
out = self.fc(context_vector)
return out, attn_weights
# 生成数据集
seq_len = 20
n_samples = 1000
data = create_sin_wave(seq_len, n_samples)
data = torch.tensor(data, dtype=torch.float32).unsqueeze(1)
# 准备训练集和测试集
def create_inout_sequences(data, seq_len):
inout_seq = []
L = len(data)
for i in range(L-seq_len):
train_seq = data[i:i+seq_len]
train_label = data[i+seq_len:i+seq_len+1]
inout_seq.append((train_seq, train_label))
return inout_seq
train_seq = create_inout_sequences(data, seq_len)
train_X = torch.stack([s[0] for s in train_seq])
train_Y = torch.stack([s[1] for s in train_seq])
# 训练模型
input_dim = 1
hidden_dim = 64
output_dim = 1
n_layers = 2
n_epochs = 100
learning_rate = 0.001
model = AttentionLSTM(input_dim, hidden_dim, output_dim, n_layers)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
model.train()
for epoch in range(n_epochs):
optimizer.zero_grad()
output, attn_weights = model(train_X)
loss = criterion(output, train_Y)
loss.backward()
optimizer.step()
if (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f}')
# 可视化结果
model.eval()
with torch.no_grad():
pred, attn_weights = model(train_X)
# 绘制实际值与预测值
plt.figure(figsize=(14, 7))
plt.plot(data.numpy(), label='True Data')
plt.plot(range(seq_len, seq_len + len(pred)), pred.numpy(), label='Predicted Data')
plt.xlabel('Time step')
plt.ylabel('Value')
plt.title('Attention-LSTM: True vs Predicted')
plt.legend()
plt.show()
# 绘制注意力权重
attn_weights = attn_weights.numpy()
plt.figure(figsize=(14, 7))
plt.imshow(attn_weights.T, aspect='auto', cmap='viridis')
plt.colorbar()
plt.xlabel('Time step')
plt.ylabel('Attention Weights')
plt.title('Attention Weights Distribution')
plt.show()
代码说明
- 数据生成:我们创建了一个正弦波数据集,用于模拟时序数据。
- Attention机制实现:定义了一个
Attention
类,用于计算注意力权重和上下文向量。 - LSTM模型实现:定义了一个
AttentionLSTM
类,该类包含LSTM层、Attention层和全连接层。 - 训练模型:通过梯度下降训练模型。
- 可视化结果:绘制预测值和实际值的对比图,以及注意力权重的分布图。
Attention机制能够增强LSTM模型的性能,使其在处理长时间依赖关系时更加有效。通过引入Attention机制,模型能够自动关注时序数据中重要的时间步,从而提高预测的准确性。
这种Attention-LSTM模型在金融预测、气象分析、医疗诊断等领域都有广泛的应用潜力。大家在未来的研究可以探索不同类型的注意力机制以及其在不同应用场景中的效果。
本文章转载微信公众号@深夜努力写Python