{ "cells": [ { "cell_type": "markdown", "id": "4e5a7aa2", "metadata": {}, "source": [ "# Transformer Decoder on GRU Encoder\n", "\n", "This notebook replaces the decoder with the full Transformer stack while keeping the encoder as a GRU, positioning the model one step away from the complete Transformer architecture.\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "e23f181d", "metadata": {}, "outputs": [], "source": [ "import math\n", "import os\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l\n", "\n", "\n", "import importlib\n", "import hw7\n", "importlib.reload(hw7)\n", "from hw7 import *\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "184d40c0", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2025-12-04T14:12:45.186444\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.10.7, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
  "from tsv_seq2seq_data import TSVSeq2SeqData\n",
  "\n",
  "# --- Data -------------------------------------------------------------------\n",
  "# The Spanish example cells further down appear to assume a model trained on the\n",
  "# local Spanish->English TSV corpus, while the final English->French cell assumes\n",
  "# d2l.MTFraEng. Set USE_TSV_DATA accordingly (and re-train) before running them.\n",
  "USE_TSV_DATA = False\n",
  "TSV_DATA_PATH = os.path.expanduser('~/Dropbox/CS6140/data/sentence_pairs_large.tsv')\n",
  "\n",
  "if USE_TSV_DATA:\n",
  "    data = TSVSeq2SeqData(\n",
  "        path=TSV_DATA_PATH,\n",
  "        batch_size=512,\n",
  "        num_steps=25,\n",
  "        min_freq=2,\n",
  "        val_frac=0.05,\n",
  "        test_frac=0.0,\n",
  "        sample_percent=1,\n",
  "    )\n",
  "else:\n",
  "    data = d2l.MTFraEng(batch_size=128)\n",
  "\n",
  "# --- Hyperparameters --------------------------------------------------------\n",
  "embed_size = 256        # embedding size, passed to the encoder only\n",
  "num_hiddens = 320       # hidden width shared by encoder and decoder\n",
  "num_layers = 3          # encoder depth\n",
  "num_blks = 3            # number of TransformerDecoder blocks\n",
  "num_heads = 8           # attention heads, used by encoder and decoder\n",
  "ffn_num_hiddens = 1280  # decoder feed-forward hidden width\n",
  "dropout = 0.35\n",
  "\n",
  "torch.manual_seed(0)  # make weight init and batch shuffling reproducible across re-runs\n",
  "\n",
  "# --- Model ------------------------------------------------------------------\n",
  "#encoder = BatchFirstGRUEncoder(len(data.src_vocab), embed_size, num_hiddens, num_layers, dropout)\n",
  "encoder = SelfAttentionAugmentedEncoder(len(data.src_vocab), embed_size, num_hiddens, num_layers,\n",
  "                                        num_heads=num_heads, dropout=dropout)\n",
  "decoder = TransformerDecoder(len(data.tgt_vocab), num_hiddens,\n",
  "                             ffn_num_hiddens, num_heads,\n",
  "                             num_blks, dropout)\n",
  "\n",
  "# tgt_pad is the target id d2l.Seq2Seq masks out of the loss; it must be the\n",
  "# '<pad>' id (with d2l's Vocab, looking up '' would return the '<unk>' id).\n",
  "model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'], lr=0.0005)\n",
  "trainer = d2l.Trainer(max_epochs=15, gradient_clip_val=1, num_gpus=1)\n",
  "trainer.fit(model, data)\n"
 ] }, { "cell_type": "code", "execution_count": null, "id": "d1c50770", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "vamos . => let's go together. | reference: go . | BLEU: 0.000\n", "me perdi . => i was disappointed with me. | reference: i got lost . | BLEU: 0.000\n", "esta tranquilo . => this sentence is broken. | reference: he is calm . | BLEU: 0.000\n", "estoy en casa . => i'm at home right now. | reference: i am at home . | BLEU: 0.447\n", "donde esta el tren ? => where the train will be repaired | reference: where is the train ? | BLEU: 0.473\n", "necesito ayuda urgente . => i need help help help help help you help me any more. | reference: i need urgent help . | BLEU: 0.275\n", "ayer llovio mucho en la ciudad . 
=> yesterday we spent a lot of snow in the city yesterday. | reference: it rained a lot in the city yesterday . | BLEU: 0.547\n", "los ninos estan jugando en el parque . => the playing in the park in the park during the park. | reference: the children are playing in the park . | BLEU: 0.439\n", "ella quiere aprender a hablar ingles muy bien . => she wants to learn to learn to speak ill very good english well. | reference: she wants to learn to speak english very well . | BLEU: 0.630\n", "cuando llegara el proximo tren a madrid ? => when i arrived, the train to madrid to madrid to madrid | reference: when will the next train to madrid arrive ? | BLEU: 0.422\n" ] } ], "source": [
  "# Greedy decoding (model.predict_step) of hand-written Spanish sentences.\n",
  "# NOTE: these inputs/references assume a model trained on the Spanish->English TSV\n",
  "# corpus; against d2l.MTFraEng's English source vocabulary most of these tokens are\n",
  "# presumably out-of-vocabulary.\n",
  "examples = ['vamos .', 'me perdi .', 'esta tranquilo .', 'estoy en casa .', 'donde esta el tren ?',\n",
  "            'necesito ayuda urgente .', 'ayer llovio mucho en la ciudad .',\n",
  "            'los ninos estan jugando en el parque .', 'ella quiere aprender a hablar ingles muy bien .',\n",
  "            'cuando llegara el proximo tren a madrid ?']\n",
  "references = ['go .', 'i got lost .', 'he is calm .', 'i am at home .', 'where is the train ?',\n",
  "              'i need urgent help .', 'it rained a lot in the city yesterday .',\n",
  "              'the children are playing in the park .', 'she wants to learn to speak english very well .',\n",
  "              'when will the next train to madrid arrive ?']\n",
  "assert len(examples) == len(references), 'every example needs a reference'\n",
  "\n",
  "preds, _ = model.predict_step(\n",
  "    data.build(examples, references), d2l.try_gpu(), data.num_steps)\n",
  "for src, tgt, pred in zip(examples, references, preds):\n",
  "    translation = []\n",
  "    for token in data.tgt_vocab.to_tokens(pred):\n",
  "        if token == '<eos>':  # tokens after end-of-sequence are not part of the translation\n",
  "            break\n",
  "        translation.append(token)\n",
  "    hypo = ' '.join(translation)\n",
  "    print(f\"{src} => {hypo} | reference: {tgt} | BLEU: {d2l.bleu(hypo, tgt, k=2):.3f}\")\n"
 ] }, { "cell_type": "code", "execution_count": 7, "id": "00a0de84", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "vamos . => we're we're we're we're we're we're we're on we're we're we're we're we're we're we're we're we're we're we're we're we're we're on our we're on our we're on our we're on our we're on our we're on our we're | reference: go . | BLEU: 0.000\n", "me perdi . => i was taking my to kill me. i moved to kill me. i moved to kill me. i moved to me. i moved to me. i moved to me. i moved to me. i moved to me. myself. | reference: i got lost . | BLEU: 0.000\n", "esta tranquilo . => this christmas holiday this movie this christmas holiday this movie this movie this evening, this movie this movie this movie. let's read this movie. let's read this movie. let's this movie. this movie. this movie. this movie. this movie. this | reference: he is calm . | BLEU: 0.000\n", "estoy en casa . => i'm at home right now, i'm at home right now. i'm at home right now. i'm at home right now. i'm at home right now. i'm at home. i'm at home right now. i'm at home tomorrow. i'm at home. | reference: i am at home . | BLEU: 0.089\n", "donde esta el tren ? 
=> where is the train will the train will the train will the train will the train will open where this train will open where this train will open the train will the train will the train will the train will | reference: where is the train ? | BLEU: 0.167\n", "necesito ayuda urgente . => i need to need help. i need help. i need help. i need help. i need help. i need help. i need help. i need help. i need help. i need help. others need help. others need help. others need | reference: i need urgent help . | BLEU: 0.089\n", "ayer llovio mucho en la ciudad . => yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday yesterday | reference: it rained a lot in the city yesterday . | BLEU: 0.000\n", "los ninos estan jugando en el parque . => played playing video games in the park playing in the park in the park the park park park park park park park park park park park park park park park park park park park park park park park park | reference: the children are playing in the park . | BLEU: 0.186\n", "ella quiere aprender a hablar ingles muy bien . => she really she wants to learn how to learn how to speak how to speak english very well. she can speak how to speak german very well. she can speak english very well. she can speak english very well. she | reference: she wants to learn to speak english very well . | BLEU: 0.280\n", "cuando llegara el proximo tren a madrid ? => when the the rang when the i got to madrid woke up when the i got to madrid to madrid to madrid to madrid to madrid to madrid to madrid to madrid to madrid to madrid | reference: when will the next train to madrid arrive ? 
| BLEU: 0.127\n" ] } ], "source": [
  "# Beam-search decoding (beam_search_translate, from hw7) of the Spanish sentences.\n",
  "# NOTE: like the greedy Spanish cell, this assumes the Spanish->English model.\n",
  "BEAM_SIZE = 5\n",
  "MAX_DECODE_STEPS = 40\n",
  "\n",
  "examples = ['vamos .', 'me perdi .', 'esta tranquilo .', 'estoy en casa .', 'donde esta el tren ?',\n",
  "            'necesito ayuda urgente .', 'ayer llovio mucho en la ciudad .',\n",
  "            'los ninos estan jugando en el parque .', 'ella quiere aprender a hablar ingles muy bien .',\n",
  "            'cuando llegara el proximo tren a madrid ?']\n",
  "references = ['go .', 'i got lost .', 'he is calm .', 'i am at home .', 'where is the train ?',\n",
  "              'i need urgent help .', 'it rained a lot in the city yesterday .',\n",
  "              'the children are playing in the park .', 'she wants to learn to speak english very well .',\n",
  "              'when will the next train to madrid arrive ?']\n",
  "\n",
  "for src, tgt in zip(examples, references):\n",
  "    src_tokens = [data.src_vocab[token] for token in src.lower().split()]\n",
  "    pred_ids = beam_search_translate(model, src_tokens, data,\n",
  "                                     beam_size=BEAM_SIZE, max_steps=MAX_DECODE_STEPS)\n",
  "    translation = data.tgt_vocab.to_tokens(pred_ids)\n",
  "    # d2l's Vocab.to_tokens returns a bare token string for a length-1 input;\n",
  "    # wrap it so ' '.join does not split it into characters.\n",
  "    if isinstance(translation, str):\n",
  "        translation = [translation]\n",
  "    # Drop '<eos>' and anything after it (if present), matching the greedy cells.\n",
  "    if '<eos>' in translation:\n",
  "        translation = translation[:translation.index('<eos>')]\n",
  "    hypo = ' '.join(translation)\n",
  "    print(f\"{src} => {hypo} | reference: {tgt} | BLEU: {d2l.bleu(hypo, tgt, k=2):.3f}\")\n"
 ] }, { "cell_type": "code", "execution_count": 3, "id": "93b77a91", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go . => ['va', '', '!'], bleu,0.000\n", "i lost . => ['je', 'suis', '', '.'], bleu,0.000\n", "he's calm . => ['il', 'est', '', '.'], bleu,0.658\n", "i'm home . 
=> ['je', 'suis', '', '.'], bleu,0.512\n" ] } ], "source": [
  "# Greedy decoding of the d2l English->French sanity-check sentences (assumes d2l.MTFraEng).\n",
  "engs = ['go .', 'i lost .', \"he's calm .\", \"i'm home .\"]  # evaluation prompts\n",
  "fras = ['va !', \"j'ai perdu .\", 'il est calme .', 'je suis chez moi .']  # ground-truth translations\n",
  "preds, _ = model.predict_step(\n",
  "    data.build(engs, fras), d2l.try_gpu(), data.num_steps)  # batch translation (GPU if available, else CPU)\n",
  "for en, fr, p in zip(engs, fras, preds):\n",
  "    translation = []  # predicted tokens up to, but not including, '<eos>'\n",
  "    for token in data.tgt_vocab.to_tokens(p):\n",
  "        if token == '<eos>':\n",
  "            break  # decoder signalled end of sequence\n",
  "        translation.append(token)\n",
  "    print(f'{en} => {translation}, bleu,'\n",
  "          f\"{d2l.bleu(' '.join(translation), fr, k=2):.3f}\")  # per-sentence BLEU\n"
 ] } ], "metadata": { "kernelspec": { "display_name": "Python_mac_d2l", "language": "python", "name": "d2l" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.14.0" } }, "nbformat": 4, "nbformat_minor": 5 }