{ "cells": [ { "cell_type": "markdown", "id": "2fd262de", "metadata": {}, "source": [ "# GRU Encoder with Self-Attention Hybrid\n", "\n", "This notebook keeps the multi-head decoder from the previous step but augments the GRU encoder with a **self-attention block**.\n", "The encoder first runs a stacked GRU and then applies multi-head self-attention plus Add&Norm to capture long-range source\n", "relationships before handing contextual states to the decoder.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "79b4fb14", "metadata": {}, "outputs": [], "source": [ "\n", "\n", "\n", "import os\n", "import math\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l\n", "from tsv_seq2seq_data import TSVSeq2SeqData\n", "\n", "import importlib\n", "import hw7\n", "importlib.reload(hw7)\n", "from hw7 import *\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "3102dc75", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2025-11-29T17:25:48.867075\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.9.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
 "# Reproducibility: seed torch's global RNG so weight initialisation and dropout are repeatable across runs.\n",
 "SEED = 0\n",
 "torch.manual_seed(SEED)\n",
 "\n",
 "# Dataset path: set the SEQ2SEQ_DATA environment variable to use another file without editing the notebook.\n",
 "data_path = os.environ.get(\n",
 "    'SEQ2SEQ_DATA',\n",
 "    os.path.expanduser('~/Dropbox/CS6140/data/sentence_pairs_large.tsv'))\n",
 "data = TSVSeq2SeqData(\n",
 "    path=data_path,\n",
 "    batch_size=512,\n",
 "    num_steps=25,\n",
 "    min_freq=2,\n",
 "    val_frac=0.05,\n",
 "    test_frac=0.0,\n",
 "    sample_percent=1,\n",
 ")\n",
 "\n",
 "# Hyperparameters shared by the encoder and the decoder.\n",
 "embed_size = 256\n",
 "num_hiddens = 320\n",
 "num_layers = 3\n",
 "dropout = 0.35\n",
 "num_heads = 4\n",
 "\n",
 "encoder = SelfAttentionAugmentedEncoder(len(data.src_vocab), embed_size, num_hiddens, num_layers,\n",
 "                                        num_heads=num_heads, dropout=dropout)\n",
 "decoder = MultiHeadSeq2SeqDecoder(len(data.tgt_vocab), embed_size,\n",
 "                                  num_hiddens, num_layers, num_heads=num_heads, dropout=dropout)\n",
 "# tgt_pad must be the '<pad>' index: d2l.Seq2Seq uses it to mask padded target positions out of the loss.\n",
 "model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'], lr=0.0015)\n",
 "trainer = d2l.Trainer(max_epochs=15, gradient_clip_val=1, num_gpus=1)\n",
 "trainer.fit(model, data)\n"
 ] }, { "cell_type": "code", "execution_count": 3, "id": "4a4e557d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "vamos . => let's go on. | reference: go . BLEU: 0.000\n", "me perdi . => i caught a wheat in the wind. | reference: i got lost . BLEU: 0.000\n", "esta tranquilo . => we worked through this evening. | reference: he is calm . BLEU: 0.000\n", "estoy en casa . => i'm home at all. | reference: i am at home . BLEU: 0.000\n", "donde esta el tren ? => where this train goes on? | reference: where is the train ? BLEU: 0.000\n", "necesito ayuda urgente . => i require urgent help. | reference: i need urgent help . BLEU: 0.000\n", "ayer llovio mucho en la ciudad . => yesterday long rained in the city was in 1992. | reference: it rained a lot in the city yesterday . BLEU: 0.527\n", "los ninos estan jugando en el parque . => the played in the park together. | reference: the children are playing in the park . BLEU: 0.498\n", "ella quiere aprender a hablar ingles muy bien . 
=> she wants to learn to speak very fast, very well. | reference: she wants to learn to speak english very well . BLEU: 0.722\n", "cuando llegara el proximo tren a madrid ? => when the train stops, call that bird is rotten. | reference: when will the next train to madrid arrive ? BLEU: 0.000\n" ] } ], "source": [
 "# Fixed evaluation set: Spanish sources paired index-by-index with English references.\n",
 "examples = ['vamos .', 'me perdi .', 'esta tranquilo .', 'estoy en casa .', 'donde esta el tren ?', 'necesito ayuda urgente .',\n",
 "            'ayer llovio mucho en la ciudad .', 'los ninos estan jugando en el parque .', 'ella quiere aprender a hablar ingles muy bien .',\n",
 "            'cuando llegara el proximo tren a madrid ?']\n",
 "\n",
 "references = ['go .', 'i got lost .', 'he is calm .', 'i am at home .', 'where is the train ?',\n",
 "              'i need urgent help .', 'it rained a lot in the city yesterday .',\n",
 "              'the children are playing in the park .', 'she wants to learn to speak english very well .', 'when will the next train to madrid arrive ?']\n",
 "# zip() below silently drops unmatched items, so fail loudly if the two lists drift apart.\n",
 "assert len(examples) == len(references), 'each example needs exactly one reference'\n",
 "\n",
 "# Greedy decoding of the whole batch; each prediction is cut at its first '<eos>'.\n",
 "preds, _ = model.predict_step(\n",
 "    data.build(examples, references), d2l.try_gpu(), data.num_steps)\n",
 "for src, tgt, pred in zip(examples, references, preds):\n",
 "    translation = []\n",
 "    for token in data.tgt_vocab.to_tokens(pred):\n",
 "        if token == '<eos>':\n",
 "            break\n",
 "        translation.append(token)\n",
 "\n",
 "    hypo = ' '.join(translation)\n",
 "    print(f\"{src} => {hypo} | reference: {tgt} BLEU: {d2l.bleu(hypo, tgt, k=2):.3f}\")"
 ] }, { "cell_type": "code", "execution_count": 4, "id": "9934d874", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "vamos . => we are going to be rainy. now. | reference: go . | BLEU: 0.000\n", "me perdi . => forgive me for me. i have no longer with it. | reference: i got lost . | BLEU: 0.000\n", "esta tranquilo . => this server is this term. and i have to be ashamed of the way. do you need to do? | reference: he is calm . | BLEU: 0.000\n", "estoy en casa . => i am at home by the other. i have a weak one. than i are in a 1000 animal. | reference: i am at home . | BLEU: 0.167\n", "donde esta el tren ? => i remember this place where i could. it. | reference: where is the train ? | BLEU: 0.000\n", "necesito ayuda urgente . => i need a lot of help. actually. i have no more than 50 employees. than to do a few words maybe? please? | reference: i need urgent help . | BLEU: 0.089\n", "ayer llovio mucho en la ciudad . => yesterday yesterday yesterday i was still in a week. | reference: it rained a lot in the city yesterday . | BLEU: 0.000\n", "los ninos estan jugando en el parque . => i the on the sidewalk. in the age of speaking. | reference: the children are playing in the park . | BLEU: 0.258\n", "ella quiere aprender a hablar ingles muy bien . => she wants to become very good and speaking to herself. | reference: she wants to learn to speak english very well . | BLEU: 0.168\n", "cuando llegara el proximo tren a madrid ? => when i go to the united states, i know about it. it was a way of a week. it should be over with. with. | reference: when will the next train to madrid arrive ? 
| BLEU: 0.000\n" ] } ], "source": [
 "# Beam-search decoding, one sentence at a time: each source is lower-cased, split on whitespace\n",
 "# and mapped to ids with the source vocabulary.\n",
 "# NOTE(review): unlike data.build in the greedy cell, nothing is appended or padded here;\n",
 "# confirm beam_search_translate does not expect a trailing '<eos>' on the source.\n",
 "for src, tgt in zip(examples, references):\n",
 "    src_tokens = [data.src_vocab[token] for token in src.lower().split()]\n",
 "    pred_ids = beam_search_translate(model, src_tokens, data, beam_size=5, max_steps=40)\n",
 "    translation = data.tgt_vocab.to_tokens(pred_ids)\n",
 "    # Cut at the first '<eos>' (if present) so the hypothesis is post-processed like the greedy cell's.\n",
 "    if '<eos>' in translation:\n",
 "        translation = translation[:translation.index('<eos>')]\n",
 "    hypo = ' '.join(translation)\n",
 "    print(f\"{src} => {hypo} | reference: {tgt} | BLEU: {d2l.bleu(hypo, tgt, k=2):.3f}\")"
 ] } ], "metadata": { "kernelspec": { "display_name": "Python(titanic2_d2lbook_env)", "language": "python", "name": "titanic2_d2lbook_env" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }