{
"cells": [
{
"cell_type": "markdown",
"id": "2fd262de",
"metadata": {},
"source": [
"# GRU Encoder with Self-Attention Hybrid\n",
"\n",
"This notebook keeps the multi-head decoder from the previous step but augments the GRU encoder with a **self-attention block**.\n",
"The encoder first runs a stacked GRU and then applies multi-head self-attention plus Add&Norm to capture long-range source\n",
"relationships before handing contextual states to the decoder.\n"
]
},
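{
"cell_type": "markdown",
"id": "a9f31c02",
"metadata": {},
"source": [
"The cell below is a minimal sketch of what such an encoder can look like, assuming the d2l convention that the encoder\n",
"returns outputs of shape `(num_steps, batch_size, num_hiddens)` together with the final GRU state. The actual\n",
"`SelfAttentionAugmentedEncoder` imported from `hw7` may differ (for example, it may mask padded positions or apply\n",
"attention after every GRU layer); the class and variable names here are illustrative only.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b81d5e47",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"\n",
"class GRUSelfAttentionEncoderSketch(nn.Module):\n",
"    \"\"\"Stacked GRU followed by multi-head self-attention with Add & Norm (sketch).\"\"\"\n",
"\n",
"    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
"                 num_heads=4, dropout=0.1):\n",
"        super().__init__()\n",
"        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
"        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)\n",
"        self.attention = nn.MultiheadAttention(num_hiddens, num_heads,\n",
"                                               dropout=dropout, batch_first=True)\n",
"        self.addnorm = nn.LayerNorm(num_hiddens)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, X, *args):\n",
"        # X: (batch_size, num_steps) token ids\n",
"        embs = self.embedding(X.t().long())   # (num_steps, batch_size, embed_size)\n",
"        outputs, state = self.rnn(embs)       # outputs: (num_steps, batch_size, num_hiddens)\n",
"        seq = outputs.permute(1, 0, 2)        # batch-first layout for nn.MultiheadAttention\n",
"        attn_out, _ = self.attention(seq, seq, seq)       # self-attention across source positions\n",
"        seq = self.addnorm(seq + self.dropout(attn_out))  # residual connection + layer norm\n",
"        return seq.permute(1, 0, 2), state    # back to (num_steps, batch_size, num_hiddens)\n"
]
},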
{
"cell_type": "code",
"execution_count": null,
"id": "79b4fb14",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"\n",
"import os\n",
"import math\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l\n",
"from tsv_seq2seq_data import TSVSeq2SeqData\n",
"\n",
"import importlib\n",
"import hw7\n",
"importlib.reload(hw7)\n",
"from hw7 import *\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3102dc75",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data_path = os.path.expanduser('~/Dropbox/CS6140/data/sentence_pairs_large.tsv')\n",
"data = TSVSeq2SeqData(\n",
" path=data_path,\n",
" batch_size=512,\n",
" num_steps=25,\n",
" min_freq=2,\n",
" val_frac=0.05,\n",
" test_frac=0.0,\n",
" sample_percent=1,\n",
")\n",
"\n",
"embed_size = 256\n",
"num_hiddens = 320 \n",
"num_blks = 3 \n",
"num_layers =3\n",
"dropout = 0.35 \n",
"num_heads = 4\n",
"\n",
"encoder = SelfAttentionAugmentedEncoder(len(data.src_vocab), embed_size, num_hiddens, num_layers,\n",
" num_heads=num_heads, dropout=dropout)\n",
"decoder = MultiHeadSeq2SeqDecoder(len(data.tgt_vocab), embed_size,\n",
" num_hiddens, num_layers, num_heads=num_heads, dropout=dropout)\n",
"model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab[''], lr=0.0015)\n",
"trainer = d2l.Trainer(max_epochs=15, gradient_clip_val=1, num_gpus=1)\n",
"trainer.fit(model, data)\n"
]
},
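{
"cell_type": "markdown",
"id": "b7c1e0aa",
"metadata": {},
"source": [
"The next cell translates a handful of held-out Spanish sentences with greedy decoding and scores each hypothesis with\n",
"`d2l.bleu(hypo, tgt, k=2)`, i.e. BLEU over n-grams up to bigrams. As defined in the d2l text, the score is\n",
"\n",
"$$\\exp\\left(\\min\\left(0,\\ 1 - \\frac{\\mathrm{len}_{\\text{label}}}{\\mathrm{len}_{\\text{pred}}}\\right)\\right) \\prod_{n=1}^{k} p_n^{1/2^n},$$\n",
"\n",
"where $p_n$ is the $n$-gram precision of the hypothesis against the reference and the exponential factor penalizes\n",
"hypotheses shorter than the reference. With $k=2$ a hypothesis can score moderately well just by sharing many unigrams\n",
"and bigrams with the reference, even when the overall translation is poor.\n"
]
},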
{
"cell_type": "code",
"execution_count": 3,
"id": "4a4e557d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"vamos . => let's go on. | reference: go . BLEU: 0.000\n",
"me perdi . => i caught a wheat in the wind. | reference: i got lost . BLEU: 0.000\n",
"esta tranquilo . => we worked through this evening. | reference: he is calm . BLEU: 0.000\n",
"estoy en casa . => i'm home at all. | reference: i am at home . BLEU: 0.000\n",
"donde esta el tren ? => where this train goes on? | reference: where is the train ? BLEU: 0.000\n",
"necesito ayuda urgente . => i require urgent help. | reference: i need urgent help . BLEU: 0.000\n",
"ayer llovio mucho en la ciudad . => yesterday long rained in the city was in 1992. | reference: it rained a lot in the city yesterday . BLEU: 0.527\n",
"los ninos estan jugando en el parque . => the played in the park together. | reference: the children are playing in the park . BLEU: 0.498\n",
"ella quiere aprender a hablar ingles muy bien . => she wants to learn to speak very fast, very well. | reference: she wants to learn to speak english very well . BLEU: 0.722\n",
"cuando llegara el proximo tren a madrid ? => when the train stops, call that bird is rotten. | reference: when will the next train to madrid arrive ? BLEU: 0.000\n"
]
}
],
"source": [
"# examples = ['vamos .', 'me perdi .', 'esta tranquilo .', 'estoy en casa .', 'donde esta el tren ?']\n",
"# references = ['go .', 'i got lost .', 'he is calm .', 'i am at home .', 'where is the train ?']\n",
"\n",
"# preds, _ = model.predict_step(\n",
"# data.build(examples, references), d2l.try_gpu(), data.num_steps)\n",
"# for src, tgt, pred in zip(examples, references, preds):\n",
"# translation = []\n",
"# for token in data.tgt_vocab.to_tokens(pred):\n",
"# if token == '':\n",
"# break\n",
"# translation.append(token)\n",
"# print(f\"{src} => {' '.join(translation)} | reference: {tgt}\")\n",
"\n",
"\n",
"\n",
"examples = ['vamos .', 'me perdi .', 'esta tranquilo .', 'estoy en casa .', 'donde esta el tren ?', 'necesito ayuda urgente .',\n",
" 'ayer llovio mucho en la ciudad .', 'los ninos estan jugando en el parque .', 'ella quiere aprender a hablar ingles muy bien .',\n",
" 'cuando llegara el proximo tren a madrid ?']\n",
"\n",
"references = ['go .', 'i got lost .', 'he is calm .', 'i am at home .', 'where is the train ?',\n",
" 'i need urgent help .', 'it rained a lot in the city yesterday .',\n",
" 'the children are playing in the park .', 'she wants to learn to speak english very well .', 'when will the next train to madrid arrive ?']\n",
"\n",
"preds, _ = model.predict_step(\n",
" data.build(examples, references), d2l.try_gpu(), data.num_steps)\n",
"for src, tgt, pred in zip(examples, references, preds):\n",
" translation = []\n",
" for token in data.tgt_vocab.to_tokens(pred):\n",
" if token == '':\n",
" break\n",
" translation.append(token)\n",
" \n",
" hypo = ' '.join(translation)\n",
" print(f\"{src} => {hypo} | reference: {tgt} BLEU: {d2l.bleu(hypo, tgt, k=2):.3f}\")\n",
" "
]
},
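{
"cell_type": "markdown",
"id": "c3d9f4e1",
"metadata": {},
"source": [
"The final cell repeats the translations with `beam_search_translate` from `hw7`, whose implementation is not shown in this\n",
"notebook. The sketch below is only a generic, length-normalized beam search over an abstract `step_fn` that returns\n",
"next-token log-probabilities, included to illustrate the idea; the function and parameter names (`beam_search_sketch`,\n",
"`step_fn`, `alpha`) are illustrative assumptions, not the `hw7` API.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4e8a2b7",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"def beam_search_sketch(step_fn, bos_id, eos_id, beam_size=5, max_steps=40, alpha=0.75):\n",
"    \"\"\"Generic length-normalized beam search (illustrative sketch, not the hw7 code).\n",
"\n",
"    step_fn(prefix) must return a 1-D tensor of log-probabilities over the target\n",
"    vocabulary for the next token, given the list of token ids in ``prefix``.\n",
"    \"\"\"\n",
"    beams = [([bos_id], 0.0)]  # (token ids, summed log-probability)\n",
"    finished = []\n",
"    for _ in range(max_steps):\n",
"        candidates = []\n",
"        for prefix, score in beams:\n",
"            if prefix[-1] == eos_id:  # hypothesis already ended; set it aside\n",
"                finished.append((prefix, score))\n",
"                continue\n",
"            log_probs = step_fn(prefix)\n",
"            topk = torch.topk(log_probs, beam_size)\n",
"            for logp, idx in zip(topk.values.tolist(), topk.indices.tolist()):\n",
"                candidates.append((prefix + [idx], score + logp))\n",
"        if not candidates:  # every beam has emitted the end-of-sequence token\n",
"            break\n",
"        # Keep only the beam_size highest-scoring partial hypotheses.\n",
"        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]\n",
"    finished.extend(beams)\n",
"    # Length normalization: divide by len**alpha so long hypotheses are not unduly penalized.\n",
"    best, _ = max(finished, key=lambda c: c[1] / (len(c[0]) ** alpha))\n",
"    return [t for t in best[1:] if t != eos_id]  # drop the leading <bos> and any <eos>\n"
]
},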
{
"cell_type": "code",
"execution_count": 4,
"id": "9934d874",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"vamos . => we are going to be rainy. now. | reference: go . | BLEU: 0.000\n",
"me perdi . => forgive me for me. i have no longer with it. | reference: i got lost . | BLEU: 0.000\n",
"esta tranquilo . => this server is this term. and i have to be ashamed of the way. do you need to do? | reference: he is calm . | BLEU: 0.000\n",
"estoy en casa . => i am at home by the other. i have a weak one. than i are in a 1000 animal. | reference: i am at home . | BLEU: 0.167\n",
"donde esta el tren ? => i remember this place where i could. it. | reference: where is the train ? | BLEU: 0.000\n",
"necesito ayuda urgente . => i need a lot of help. actually. i have no more than 50 employees. than to do a few words maybe? please? | reference: i need urgent help . | BLEU: 0.089\n",
"ayer llovio mucho en la ciudad . => yesterday yesterday yesterday i was still in a week. | reference: it rained a lot in the city yesterday . | BLEU: 0.000\n",
"los ninos estan jugando en el parque . => i the on the sidewalk. in the age of speaking. | reference: the children are playing in the park . | BLEU: 0.258\n",
"ella quiere aprender a hablar ingles muy bien . => she wants to become very good and speaking to herself. | reference: she wants to learn to speak english very well . | BLEU: 0.168\n",
"cuando llegara el proximo tren a madrid ? => when i go to the united states, i know about it. it was a way of a week. it should be over with. with. | reference: when will the next train to madrid arrive ? | BLEU: 0.000\n"
]
}
],
"source": [
"for src, tgt in zip(examples, references):\n",
" src_sentence = src.lower().split()\n",
" src_tokens = [data.src_vocab[token] for token in src_sentence]\n",
" pred_ids = beam_search_translate(model, src_tokens, data, beam_size=5, max_steps=40)\n",
" translation = data.tgt_vocab.to_tokens(pred_ids)\n",
" hypo = ' '.join(translation)\n",
" print(f\"{src} => {hypo} | reference: {tgt} | BLEU: {d2l.bleu(hypo, tgt, k=2):.3f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python(titanic2_d2lbook_env)",
"language": "python",
"name": "titanic2_d2lbook_env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}