First, the link to the original paper:
The authors have not released official code, but there are unofficial implementations of xLSTM. The versions I found all appear to be the same code, and they were posted without whitespace or indentation, so they could not be copied and used directly. I therefore used an AI assistant to clean up the formatting so the code below can be copied and run as-is. It contains a CausalConv1D and a BlockDiagonal helper, sLSTM and mLSTM blocks, and an xLSTM wrapper that stacks them.
import torch
import torch.nn as nn
import torch.nn.functional as F
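# CausalConv1D: a 1-D convolution made causal by padding and then trimming the
# trailing (kernel_size - 1) * dilation positions, so the output at step t only
# depends on inputs up to step t.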
class CausalConv1D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super(CausalConv1D, self).__init__()
        self.padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.padding, dilation=dilation, **kwargs)

    def forward(self, x):
        x = self.conv(x)
        return x[:, :, :-self.padding]
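# BlockDiagonal: splits the feature dimension into num_blocks chunks and applies an
# independent nn.Linear to each chunk, which is equivalent to a single block-diagonal
# weight matrix (one projection per head).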
class BlockDiagonal(nn.Module):
    def __init__(self, in_features, out_features, num_blocks):
        super(BlockDiagonal, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_blocks = num_blocks

        assert in_features % num_blocks == 0
        assert out_features % num_blocks == 0

        block_in_features = in_features // num_blocks
        block_out_features = out_features // num_blocks

        self.blocks = nn.ModuleList([
            nn.Linear(block_in_features, block_out_features)
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        x = x.chunk(self.num_blocks, dim=-1)
        x = [block(x_i) for block, x_i in zip(self.blocks, x)]
        x = torch.cat(x, dim=-1)
        return x
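# sLSTMBlock: the scalar-memory block. i and f are exponential gates; m_t is the
# stabilizer state (a running maximum in log space) that keeps the exponentials from
# overflowing, and n_t is the normalizer. The block ends with GroupNorm, a GELU-gated
# up/down projection, and a residual connection back to the input.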
class sLSTMBlock(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads, proj_factor=4/3):
        super(sLSTMBlock, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.proj_factor = proj_factor

        assert hidden_size % num_heads == 0
        assert proj_factor > 0

        self.layer_norm = nn.LayerNorm(input_size)
        self.causal_conv = CausalConv1D(1, 1, 4)

        self.Wz = BlockDiagonal(input_size, hidden_size, num_heads)
        self.Wi = BlockDiagonal(input_size, hidden_size, num_heads)
        self.Wf = BlockDiagonal(input_size, hidden_size, num_heads)
        self.Wo = BlockDiagonal(input_size, hidden_size, num_heads)

        self.Rz = BlockDiagonal(hidden_size, hidden_size, num_heads)
        self.Ri = BlockDiagonal(hidden_size, hidden_size, num_heads)
        self.Rf = BlockDiagonal(hidden_size, hidden_size, num_heads)
        self.Ro = BlockDiagonal(hidden_size, hidden_size, num_heads)

        self.group_norm = nn.GroupNorm(num_heads, hidden_size)

        self.up_proj_left = nn.Linear(hidden_size, int(hidden_size * proj_factor))
        self.up_proj_right = nn.Linear(hidden_size, int(hidden_size * proj_factor))
        self.down_proj = nn.Linear(int(hidden_size * proj_factor), input_size)
    def forward(self, x, prev_state):
        assert x.size(-1) == self.input_size
        h_prev, c_prev, n_prev, m_prev = prev_state

        x_norm = self.layer_norm(x)
        x_conv = F.silu(self.causal_conv(x_norm.unsqueeze(1)).squeeze(1))

        z = torch.tanh(self.Wz(x) + self.Rz(h_prev))
        o = torch.sigmoid(self.Wo(x) + self.Ro(h_prev))
        i_tilde = self.Wi(x_conv) + self.Ri(h_prev)
        f_tilde = self.Wf(x_conv) + self.Rf(h_prev)

        m_t = torch.max(f_tilde + m_prev, i_tilde)
        i = torch.exp(i_tilde - m_t)
        f = torch.exp(f_tilde + m_prev - m_t)

        c_t = f * c_prev + i * z
        n_t = f * n_prev + i
        h_t = o * c_t / n_t

        output = h_t
        output_norm = self.group_norm(output)
        output_left = self.up_proj_left(output_norm)
        output_right = self.up_proj_right(output_norm)
        output_gated = F.gelu(output_right)
        output = output_left * output_gated
        output = self.down_proj(output)
        final_output = output + x

        return final_output, (h_t, c_t, n_t, m_t)
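# sLSTM: a stack of sLSTMBlocks applied one time step at a time. The input layout is
# (seq_len, batch, input_size), or (batch, seq_len, input_size) with batch_first=True.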
class sLSTM(nn.Module):
    # TODO: Add bias, dropout, bidirectional
    def __init__(self, input_size, hidden_size, num_heads, num_layers=1, batch_first=False, proj_factor=4/3):
        super(sLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.proj_factor_slstm = proj_factor
        self.layers = nn.ModuleList([sLSTMBlock(input_size, hidden_size, num_heads, proj_factor) for _ in range(num_layers)])
    def forward(self, x, state=None):
        assert x.ndim == 3
        if self.batch_first: x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()

        if state is not None:
            state = torch.stack(list(state))
            assert state.ndim == 4
            num_hidden, state_num_layers, state_batch_size, state_hidden_size = state.size()
            assert num_hidden == 4
            assert state_num_layers == self.num_layers
            assert state_batch_size == batch_size
            assert state_hidden_size == self.hidden_size
            state = state.transpose(0, 1)
        else:
            # h, c, n, m for every layer, created on the same device/dtype as the input
            state = torch.zeros(self.num_layers, 4, batch_size, self.hidden_size, device=x.device, dtype=x.dtype)

        output = []
        for t in range(seq_len):
            x_t = x[t]
            for layer in range(self.num_layers):
                x_t, state_tuple = self.layers[layer](x_t, tuple(state[layer].clone()))
                state[layer] = torch.stack(list(state_tuple))
            output.append(x_t)

        output = torch.stack(output)
        if self.batch_first:
            output = output.transpose(0, 1)
        state = tuple(state.transpose(0, 1))
        return output, state
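# mLSTMBlock: the matrix-memory block with query/key/value projections. Note that this
# simplified implementation keeps c_t and n_t as vectors (i * (v * k) elementwise) rather
# than the full outer-product memory C_t = f * C_{t-1} + i * v k^T from the paper.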
class mLSTMBlock(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads, proj_factor=2):
        super(mLSTMBlock, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.proj_factor = proj_factor

        assert hidden_size % num_heads == 0
        assert proj_factor > 0

        self.layer_norm = nn.LayerNorm(input_size)
        self.up_proj_left = nn.Linear(input_size, int(input_size * proj_factor))
        self.up_proj_right = nn.Linear(input_size, hidden_size)
        self.down_proj = nn.Linear(hidden_size, input_size)

        self.causal_conv = CausalConv1D(1, 1, 4)
        self.skip_connection = nn.Linear(int(input_size * proj_factor), hidden_size)

        self.Wq = BlockDiagonal(int(input_size * proj_factor), hidden_size, num_heads)
        self.Wk = BlockDiagonal(int(input_size * proj_factor), hidden_size, num_heads)
        self.Wv = BlockDiagonal(int(input_size * proj_factor), hidden_size, num_heads)
        self.Wi = nn.Linear(int(input_size * proj_factor), hidden_size)
        self.Wf = nn.Linear(int(input_size * proj_factor), hidden_size)
        self.Wo = nn.Linear(int(input_size * proj_factor), hidden_size)

        self.group_norm = nn.GroupNorm(num_heads, hidden_size)
    def forward(self, x, prev_state):
        h_prev, c_prev, n_prev, m_prev = prev_state
        assert x.size(-1) == self.input_size

        x_norm = self.layer_norm(x)
        x_up_left = self.up_proj_left(x_norm)
        x_up_right = self.up_proj_right(x_norm)

        x_conv = F.silu(self.causal_conv(x_up_left.unsqueeze(1)).squeeze(1))
        x_skip = self.skip_connection(x_conv)

        q = self.Wq(x_conv)
        k = self.Wk(x_conv) / (self.head_size ** 0.5)
        v = self.Wv(x_up_left)

        i_tilde = self.Wi(x_conv)
        f_tilde = self.Wf(x_conv)
        o = torch.sigmoid(self.Wo(x_up_left))

        m_t = torch.max(f_tilde + m_prev, i_tilde)
        i = torch.exp(i_tilde - m_t)
        f = torch.exp(f_tilde + m_prev - m_t)

        c_t = f * c_prev + i * (v * k)  # v @ k.T
        n_t = f * n_prev + i * k
        # o * (C q) / max{|n^T q|, 1}, with the dot product n^T q taken per sample over the feature dim
        h_t = o * (c_t * q) / torch.abs((n_t * q).sum(dim=-1, keepdim=True)).clamp(min=1.0)

        output = h_t
        output_norm = self.group_norm(output)
        output = output_norm + x_skip
        output = output * F.silu(x_up_right)
        output = self.down_proj(output)
        final_output = output + x

        return final_output, (h_t, c_t, n_t, m_t)
class mLSTM(nn.Module):
    # TODO: Add bias, dropout, bidirectional
    def __init__(self, input_size, hidden_size, num_heads, num_layers=1, batch_first=False, proj_factor=2):
        super(mLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.proj_factor_mlstm = proj_factor
        self.layers = nn.ModuleList([mLSTMBlock(input_size, hidden_size, num_heads, proj_factor) for _ in range(num_layers)])
    def forward(self, x, state=None):
        assert x.ndim == 3
        if self.batch_first: x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()

        if state is not None:
            state = torch.stack(list(state))
            assert state.ndim == 4
            num_hidden, state_num_layers, state_batch_size, state_hidden_size = state.size()
            assert num_hidden == 4
            assert state_num_layers == self.num_layers
            assert state_batch_size == batch_size
            assert state_hidden_size == self.hidden_size
            state = state.transpose(0, 1)
        else:
            # h, c, n, m for every layer, created on the same device/dtype as the input
            state = torch.zeros(self.num_layers, 4, batch_size, self.hidden_size, device=x.device, dtype=x.dtype)

        output = []
        for t in range(seq_len):
            x_t = x[t]
            for layer in range(self.num_layers):
                x_t, state_tuple = self.layers[layer](x_t, tuple(state[layer].clone()))
                state[layer] = torch.stack(list(state_tuple))
            output.append(x_t)

        output = torch.stack(output)
        if self.batch_first:
            output = output.transpose(0, 1)
        state = tuple(state.transpose(0, 1))
        return output, state
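# xLSTM: stacks sLSTM and mLSTM blocks according to the layers argument, e.g.
# layers=['m', 's'] (or the string 'ms') gives one mLSTM block followed by one sLSTM block.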
class xLSTM(nn.Module):
    # TODO: Add bias, dropout, bidirectional
    def __init__(self, input_size, hidden_size, num_heads, layers, batch_first=False, proj_factor_slstm=4/3, proj_factor_mlstm=2):
        super(xLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.layer_types = layers
        self.num_layers = len(layers)
        self.batch_first = batch_first
        self.proj_factor_slstm = proj_factor_slstm
        self.proj_factor_mlstm = proj_factor_mlstm

        self.layers = nn.ModuleList()
        for layer_type in layers:
            if layer_type == 's':
                layer = sLSTMBlock(input_size, hidden_size, num_heads, proj_factor_slstm)
            elif layer_type == 'm':
                layer = mLSTMBlock(input_size, hidden_size, num_heads, proj_factor_mlstm)
            else:
                raise ValueError(f"Invalid layer type: {layer_type}. Choose 's' for sLSTM or 'm' for mLSTM.")
            self.layers.append(layer)
    def forward(self, x, state=None):
        assert x.ndim == 3
        if self.batch_first: x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()

        if state is not None:
            state = torch.stack(list(state))
            assert state.ndim == 4
            num_hidden, state_num_layers, state_batch_size, state_hidden_size = state.size()
            assert num_hidden == 4
            assert state_num_layers == self.num_layers
            assert state_batch_size == batch_size
            assert state_hidden_size == self.hidden_size
            state = state.transpose(0, 1)
        else:
            # h, c, n, m for every layer, created on the same device/dtype as the input
            state = torch.zeros(self.num_layers, 4, batch_size, self.hidden_size, device=x.device, dtype=x.dtype)

        output = []
        for t in range(seq_len):
            x_t = x[t]
            for layer in range(self.num_layers):
                x_t, state_tuple = self.layers[layer](x_t, tuple(state[layer].clone()))
                state[layer] = torch.stack(list(state_tuple))
            output.append(x_t)

        output = torch.stack(output)
        if self.batch_first:
            output = output.transpose(0, 1)
        state = tuple(state.transpose(0, 1))
        return output, state
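Finally, a minimal usage sketch (my own addition, not part of the implementation above; the sizes are arbitrary) to check that everything wires together:

if __name__ == "__main__":
    # Toy example: batch of 8, sequence length 32, feature size 64, 4 heads,
    # one mLSTM block followed by one sLSTM block.
    model = xLSTM(input_size=64, hidden_size=64, num_heads=4, layers=['m', 's'], batch_first=True)
    x = torch.randn(8, 32, 64)
    out, state = model(x)
    print(out.shape)  # torch.Size([8, 32, 64])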