1. 首页 > 资讯

Seq2Seq 终于把 算法搞懂了!!

Seq2Seq(Sequence-to-Sequence)模型是一种用于处理序列数据的神经网络架构,广泛应用于自然语言处理(NLP)任务,如机器翻译、文本生成、对话系统等。

它通过编码器-解码器架构将输入序列(如一个句子)映射到输出序列(另一个句子或序列)。

Seq2Seq 模型由两个主要部分组成。

编码器是一个循环神经网络(RNN)或其变体,如LSTM或GRU,用于接收输入序列并将其转换为一个固定大小的上下文向量。

编码器逐步处理输入序列的每个时间步,通过隐藏层状态不断更新输入信息的表示,直到编码到达输入序列的结尾。

这一过程的最后一个隐藏状态通常被认为是整个输入序列的摘要,传递给解码器。

class Encoder(nn.Module):def __init__(self,input_dim,embedding_dim,hidden_size,num_layers,dropout):super(Encoder,self).__init__()#note hidden size and num layersself.hidden_size = hidden_sizeself.num_layers = num_layers#create a dropout layerself.dropout = nn.Dropout(dropout)#embedding to convert input token into dense vectorsself.embedding = nn.Embedding(input_dim,embedding_dim)#bilstm layerself.lstm = nn.LSTM(embedding_dim,hidden_size,num_layers=num_layers,bidirectinotallow=True,dropout=dropout)def forward(self,src):embedded = self.dropout(self.embedding(src))out,(hidden,cell) = self.lstm(embedded)return hidden,cell

解码器也是一个RNN网络,接受编码器输出的上下文向量,并生成目标序列。

解码器在每一步会生成一个输出,并将上一步的输出作为下一步的输入,直到产生特定的终止符。

解码器的初始状态来自编码器的最后一个隐藏状态,因此可以理解为解码器是基于编码器生成的全局信息来预测输出序列。

class Decoder(nn.Module):def __init__(self,output_dim,embedding_dim,hidden_size,num_layers,dropout):super(Decoder,self).__init__()self.output_dim = output_dim#note hidden size and num layers for seq2seq classself.hidden_size = hidden_sizeself.num_layers = num_layersself.dropout = nn.Dropout(dropout)#note inputs of embedding layerself.embedding = nn.Embedding(output_dim,embedding_dim)self.lstm = nn.LSTM(embedding_dim,hidden_size,num_layers=num_layers,bidirectinotallow=True,dropout=dropout)#we apply softmax over target vocab sizeself.fc = nn.Linear(hidden_size*2,output_dim)def forward(self,input_token,hidden,cell):#adjust dimensions of input tokeninput_token = input_token.unsqueeze(0)emb = self.embedding(input_token)emb = self.dropout(emb)#note hidden and cell along with outputout,(hidden,cell) = self.lstm(emb,(hidden,cell))out = out.squeeze(0)pred = self.fc(out)return pred,hidden,cell

Seq2Seq 模型的基本工作流程如下

下面是一个使用 Seq2Seq 进行机器翻译的示例代码。

首先,我们从 HuggingFace 导入了数据集,并将其分为训练集和测试集

import numpy as npimport pandas as pdimport seaborn as snsimport matplotlib.pyplot as pltimport torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch.utils.data import>

加载源语言和目标语言的 spaCy 模型。

spaCy 是一个功能强大、可用于生产的 Python 高级自然语言处理库。

与许多其他 NLP 库不同,spaCy 专为实际使用而设计,而非研究实验。

它擅长使用预先训练的模型进行高效的文本处理,可完成标记化、词性标记、命名实体识别和依赖性解析等任务。

en_nlp = spacy.load('en_core_web_sm')de_nlp = spacy.load('de_core_news_sm')#tokenizerdef sample_tokenizer(sample,en_nlp,de_nlp,lower,max_length,sos_token,eos_token):en_tokens = [token.text for token in en_nlp.tokenizer(sample["en"])][:max_length]de_tokens = [token.text for token in de_nlp.tokenizer(sample["de"])][:max_length]if lower == True:en_tokens = [token.lower() for token in en_tokens]de_tokens = [token.lower() for token in de_tokens]en_tokens = [sos_token] + en_tokens + [eos_token]de_tokens = [sos_token] + de_tokens + [eos_token]return {"en_tokens":en_tokens,"de_tokens":de_tokens}fn_kwargs = {"en_nlp":en_nlp,"de_nlp":de_nlp,"lower":True,"max_length":1000,"sos_token":'<sos>',"eos_token":'<eos>'}train_data = train_data.map(sample_tokenizer,fn_kwargs=fn_kwargs)val_data = val_data.map(sample_tokenizer,fn_kwargs=fn_kwargs)test_data = test_data.map(sample_tokenizer,fn_kwargs=fn_kwargs)min_freq = 2specials = ['<unk>','<pad>','<sos>','<eos>']en_vocab = build_vocab_from_iterator(train_data['en_tokens'],specials=specials,min_freq=min_freq)de_vocab = build_vocab_from_iterator(train_data['de_tokens'],specials=specials,min_freq=min_freq)assert en_vocab['<unk>'] == de_vocab['<unk>']assert en_vocab['<pad>'] == de_vocab['<pad>']unk_index = en_vocab['<unk>']pad_index = en_vocab['<pad>']en_vocab.set_default_index(unk_index)de_vocab.set_default_index(unk_index)def sample_num(sample,en_vocab,de_vocab):en_ids = en_vocab.lookup_indices(sample["en_tokens"])de_ids = de_vocab.lookup_indices(sample["de_tokens"])return {"en_ids":en_ids,"de_ids":de_ids}fn_kwargs = {"en_vocab":en_vocab,"de_vocab":de_vocab}train_data = train_data.map(sample_num,fn_kwargs=fn_kwargs)val_data = val_data.map(sample_num,fn_kwargs=fn_kwargs)test_data = test_data.map(sample_num,fn_kwargs=fn_kwargs)train_data = train_data.with_format(type="torch",columns=['en_ids','de_ids'],output_all_columns=True)val_data = val_data.with_format(type="torch",columns=['en_ids','de_ids'],output_all_columns=True)test_data = test_data.with_format(type="torch",columns=['en_ids','de_ids'],output_all_columns=True)def get_collate_fn(pad_index):def collate_fn(batch):batch_en_ids = [sample["en_ids"] for sample in batch]batch_de_ids = [sample["de_ids"] for sample in batch]batch_en_ids = pad_sequence(batch_en_ids,padding_value=pad_index)batch_de_ids = pad_sequence(batch_de_ids,padding_value=pad_index)batch = {"en_ids":batch_en_ids,"de_ids":batch_de_ids}return batchreturn collate_fndef get_dataloader(dataset,batch_size,shuffle,pad_index):collate_fn = get_collate_fn(pad_index)dataloader =>

本网站的文章部分内容可能来源于网络和网友发布,仅供大家学习与参考,如有侵权,请联系站长进行删除处理,不代表本网站立场,转载者并注明出处:https://www.jmbhsh.com/zixun/31515.html

联系我们

QQ号:***

微信号:***

工作日:9:30-18:30,节假日休息