PlayTransformer/naive-transformer-time-series-test.ipynb
Sunday 9628f3a6a1 Add naive-transformer-time-series-test.ipynb
Add: `naive-transformer-time-series-test.ipynb` for demonstrating simple Transformer model with torch.nn.Transformer.
2024-04-24 21:19:42 +08:00

441 lines
93 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Implement a simple Transformer\n",
"\n",
"This notebook implements a simple Transformer model with `torch.nn.Transformer` and trains it on randomly generated (synthetic) time-series data to predict each series shifted one step ahead."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_squared_error"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Generate synthetic time series data\n",
"def generate_time_series(n_samples, n_steps):\n",
"    \"\"\"Return an (n_samples, n_steps) float32 array of noisy two-tone sine series.\n",
"\n",
"    Each row samples 0.5*sin(...) + 0.2*sin(...) on n_steps evenly spaced points in\n",
"    [0, 1], with a per-row random time offset and frequency for each tone, plus\n",
"    uniform noise in [-0.05, 0.05). Randomness comes from NumPy's global RNG, so\n",
"    seed it (np.random.seed) to get reproducible data.\n",
"    \"\"\"\n",
"    freq1, freq2, offset1, offset2 = np.random.rand(4, n_samples, 1)\n",
"    t = np.linspace(0, 1, n_steps)\n",
"    # Angular multipliers lie in [10, 20) for the first tone and [20, 40) for the second.\n",
"    slow_tone = 0.5 * np.sin((t - offset1) * (freq1 * 10 + 10))\n",
"    fast_tone = 0.2 * np.sin((t - offset2) * (freq2 * 20 + 20))\n",
"    noise = 0.1 * (np.random.rand(n_samples, n_steps) - 0.5)\n",
"    return (slow_tone + fast_tone + noise).astype(np.float32)\n",
"\n",
"# Create a transformer-based model\n",
"class TransformerModel(nn.Module):\n",
"    \"\"\"Predict one output vector per input vector with an encoder-decoder nn.Transformer.\n",
"\n",
"    Every sample is a single token with `input_size` features, and the same tensor is\n",
"    fed as both source and target sequence. A 2-D input of shape (batch, input_size)\n",
"    is treated as `batch` independent length-1 sequences, so samples never attend to\n",
"    each other. A 3-D input is passed to nn.Transformer unchanged, i.e. interpreted as\n",
"    (seq, batch, input_size) because batch_first is not set.\n",
"\n",
"    Args:\n",
"        input_size: token feature size, used as the Transformer's d_model\n",
"            (PyTorch requires it to be divisible by n_heads).\n",
"        output_size: feature size produced by the final linear layer.\n",
"        n_layers: number of encoder layers and, separately, of decoder layers.\n",
"        n_heads: number of attention heads.\n",
"        hidden_size: width of the feed-forward sublayers (dim_feedforward).\n",
"        dropout: dropout probability used inside the Transformer.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, output_size, n_layers, n_heads, hidden_size, dropout):\n",
"        super().__init__()\n",
"        self.transformer = nn.Transformer(\n",
"            d_model=input_size,\n",
"            nhead=n_heads,\n",
"            num_encoder_layers=n_layers,\n",
"            num_decoder_layers=n_layers,\n",
"            dim_feedforward=hidden_size,\n",
"            dropout=dropout,\n",
"        )\n",
"        self.fc = nn.Linear(input_size, output_size)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map x of shape (batch, input_size) to predictions of shape (batch, output_size).\"\"\"\n",
"        # nn.Transformer reads a 2-D tensor as ONE unbatched sequence of shape\n",
"        # (seq, d_model). That would turn every sample into a time step of a single\n",
"        # sequence and let samples attend to each other, so test predictions would\n",
"        # depend on the rest of the test set. Add a length-1 sequence axis instead.\n",
"        was_2d = x.dim() == 2\n",
"        if was_2d:\n",
"            x = x.unsqueeze(0)  # (1, batch, input_size) = (seq, batch, d_model)\n",
"        x = self.transformer(x, x)\n",
"        x = self.fc(x)\n",
"        return x.squeeze(0) if was_2d else x"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
}
],
"source": [
"# Configuration: random seeds and compute device (Restart & Run All reproduces the results)\n",
"SEED = 42  # same value as the train_test_split random_state used in the data cell\n",
"np.random.seed(SEED)  # generate_time_series draws from NumPy's global RNG\n",
"torch.manual_seed(SEED)  # model weight initialisation and dropout masks\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Prepare data: each series has n_steps points, giving an input window of n_steps - 1\n",
"# values and a target window of the same length shifted one step ahead.\n",
"n_samples = 1000\n",
"n_steps = 513\n",
"series = generate_time_series(n_samples, n_steps)\n",
"\n",
"# Standardize every series by its own mean/std: fitting on the transposed array makes\n",
"# each series a scaler column. Note these statistics also cover the last point of the\n",
"# series, which only appears in the target y.\n",
"scaler = StandardScaler()\n",
"series_scaled = scaler.fit_transform(series.T).T\n",
"X, y = series_scaled[:, :-1], series_scaled[:, 1:]\n",
"\n",
"# Split data into train and test sets\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"\n",
"# Convert every split to a float32 PyTorch tensor\n",
"X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor = (\n",
"    torch.tensor(arr, dtype=torch.float32) for arr in (X_train, y_train, X_test, y_test)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/torch210/lib/python3.11/site-packages/torch/nn/modules/transformer.py:282: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
" warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n",
"/root/miniconda3/envs/torch210/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/200], Loss: 1.2826976776123047\n",
"Epoch [2/200], Loss: 1.0116225481033325\n",
"Epoch [3/200], Loss: 0.635532021522522\n",
"Epoch [4/200], Loss: 0.4489162862300873\n",
"Epoch [5/200], Loss: 0.35763952136039734\n",
"Epoch [6/200], Loss: 0.2987120449542999\n",
"Epoch [7/200], Loss: 0.30284008383750916\n",
"Epoch [8/200], Loss: 0.2880264222621918\n",
"Epoch [9/200], Loss: 0.29169321060180664\n",
"Epoch [10/200], Loss: 0.2669679522514343\n",
"Epoch [11/200], Loss: 0.2674340307712555\n",
"Epoch [12/200], Loss: 0.25883418321609497\n",
"Epoch [13/200], Loss: 0.23866242170333862\n",
"Epoch [14/200], Loss: 0.21799877285957336\n",
"Epoch [15/200], Loss: 0.20396709442138672\n",
"Epoch [16/200], Loss: 0.19301307201385498\n",
"Epoch [17/200], Loss: 0.17444251477718353\n",
"Epoch [18/200], Loss: 0.16921290755271912\n",
"Epoch [19/200], Loss: 0.1575583964586258\n",
"Epoch [20/200], Loss: 0.14533430337905884\n",
"Epoch [21/200], Loss: 0.1386234015226364\n",
"Epoch [22/200], Loss: 0.12826502323150635\n",
"Epoch [23/200], Loss: 0.11779733747243881\n",
"Epoch [24/200], Loss: 0.11164677143096924\n",
"Epoch [25/200], Loss: 0.10296566784381866\n",
"Epoch [26/200], Loss: 0.09518710523843765\n",
"Epoch [27/200], Loss: 0.08982985466718674\n",
"Epoch [28/200], Loss: 0.08321435749530792\n",
"Epoch [29/200], Loss: 0.07733117789030075\n",
"Epoch [30/200], Loss: 0.07184430211782455\n",
"Epoch [31/200], Loss: 0.0663248598575592\n",
"Epoch [32/200], Loss: 0.06211206316947937\n",
"Epoch [33/200], Loss: 0.0579068660736084\n",
"Epoch [34/200], Loss: 0.054606616497039795\n",
"Epoch [35/200], Loss: 0.051191166043281555\n",
"Epoch [36/200], Loss: 0.047771405428647995\n",
"Epoch [37/200], Loss: 0.045270729809999466\n",
"Epoch [38/200], Loss: 0.04388266056776047\n",
"Epoch [39/200], Loss: 0.04167153686285019\n",
"Epoch [40/200], Loss: 0.03903637081384659\n",
"Epoch [41/200], Loss: 0.03716022148728371\n",
"Epoch [42/200], Loss: 0.035668425261974335\n",
"Epoch [43/200], Loss: 0.03426579385995865\n",
"Epoch [44/200], Loss: 0.0328022837638855\n",
"Epoch [45/200], Loss: 0.03179449960589409\n",
"Epoch [46/200], Loss: 0.030734049156308174\n",
"Epoch [47/200], Loss: 0.029527556151151657\n",
"Epoch [48/200], Loss: 0.028938110917806625\n",
"Epoch [49/200], Loss: 0.027805780991911888\n",
"Epoch [50/200], Loss: 0.02714334987103939\n",
"Epoch [51/200], Loss: 0.026013001799583435\n",
"Epoch [52/200], Loss: 0.026148412376642227\n",
"Epoch [53/200], Loss: 0.02501417137682438\n",
"Epoch [54/200], Loss: 0.024782532826066017\n",
"Epoch [55/200], Loss: 0.023976940661668777\n",
"Epoch [56/200], Loss: 0.02344522438943386\n",
"Epoch [57/200], Loss: 0.02298690192401409\n",
"Epoch [58/200], Loss: 0.022800864651799202\n",
"Epoch [59/200], Loss: 0.02203877829015255\n",
"Epoch [60/200], Loss: 0.02186082862317562\n",
"Epoch [61/200], Loss: 0.02145547792315483\n",
"Epoch [62/200], Loss: 0.02111400105059147\n",
"Epoch [63/200], Loss: 0.021111004054546356\n",
"Epoch [64/200], Loss: 0.021235451102256775\n",
"Epoch [65/200], Loss: 0.020454218611121178\n",
"Epoch [66/200], Loss: 0.020284976810216904\n",
"Epoch [67/200], Loss: 0.019960373640060425\n",
"Epoch [68/200], Loss: 0.019969569519162178\n",
"Epoch [69/200], Loss: 0.019146228209137917\n",
"Epoch [70/200], Loss: 0.01951329968869686\n",
"Epoch [71/200], Loss: 0.01905273087322712\n",
"Epoch [72/200], Loss: 0.018939871340990067\n",
"Epoch [73/200], Loss: 0.018563639372587204\n",
"Epoch [74/200], Loss: 0.018615636974573135\n",
"Epoch [75/200], Loss: 0.018304886296391487\n",
"Epoch [76/200], Loss: 0.01797996088862419\n",
"Epoch [77/200], Loss: 0.017887452617287636\n",
"Epoch [78/200], Loss: 0.017940785735845566\n",
"Epoch [79/200], Loss: 0.017486266791820526\n",
"Epoch [80/200], Loss: 0.01746840588748455\n",
"Epoch [81/200], Loss: 0.017102232202887535\n",
"Epoch [82/200], Loss: 0.017105089500546455\n",
"Epoch [83/200], Loss: 0.016733380034565926\n",
"Epoch [84/200], Loss: 0.016671737655997276\n",
"Epoch [85/200], Loss: 0.016277296468615532\n",
"Epoch [86/200], Loss: 0.01615353487432003\n",
"Epoch [87/200], Loss: 0.016070151701569557\n",
"Epoch [88/200], Loss: 0.015822922810912132\n",
"Epoch [89/200], Loss: 0.015388056635856628\n",
"Epoch [90/200], Loss: 0.015657253563404083\n",
"Epoch [91/200], Loss: 0.015209450386464596\n",
"Epoch [92/200], Loss: 0.01531938835978508\n",
"Epoch [93/200], Loss: 0.015104217454791069\n",
"Epoch [94/200], Loss: 0.015498699620366096\n",
"Epoch [95/200], Loss: 0.01536652259528637\n",
"Epoch [96/200], Loss: 0.015862323343753815\n",
"Epoch [97/200], Loss: 0.015583674423396587\n",
"Epoch [98/200], Loss: 0.014587637968361378\n",
"Epoch [99/200], Loss: 0.014702104032039642\n",
"Epoch [100/200], Loss: 0.014520714059472084\n",
"Epoch [101/200], Loss: 0.014230665750801563\n",
"Epoch [102/200], Loss: 0.014032993465662003\n",
"Epoch [103/200], Loss: 0.014286665245890617\n",
"Epoch [104/200], Loss: 0.01409598346799612\n",
"Epoch [105/200], Loss: 0.013588094152510166\n",
"Epoch [106/200], Loss: 0.013453097082674503\n",
"Epoch [107/200], Loss: 0.013409607112407684\n",
"Epoch [108/200], Loss: 0.013149403035640717\n",
"Epoch [109/200], Loss: 0.013141301460564137\n",
"Epoch [110/200], Loss: 0.01283340621739626\n",
"Epoch [111/200], Loss: 0.012824492529034615\n",
"Epoch [112/200], Loss: 0.012342739850282669\n",
"Epoch [113/200], Loss: 0.012487290427088737\n",
"Epoch [114/200], Loss: 0.012282567098736763\n",
"Epoch [115/200], Loss: 0.012123352847993374\n",
"Epoch [116/200], Loss: 0.012309921905398369\n",
"Epoch [117/200], Loss: 0.011849000118672848\n",
"Epoch [118/200], Loss: 0.011969990096986294\n",
"Epoch [119/200], Loss: 0.011744375340640545\n",
"Epoch [120/200], Loss: 0.011610714718699455\n",
"Epoch [121/200], Loss: 0.011537344194948673\n",
"Epoch [122/200], Loss: 0.011500836350023746\n",
"Epoch [123/200], Loss: 0.011424414813518524\n",
"Epoch [124/200], Loss: 0.011428999714553356\n",
"Epoch [125/200], Loss: 0.011216061189770699\n",
"Epoch [126/200], Loss: 0.011091525666415691\n",
"Epoch [127/200], Loss: 0.011024504899978638\n",
"Epoch [128/200], Loss: 0.011121418327093124\n",
"Epoch [129/200], Loss: 0.010962490923702717\n",
"Epoch [130/200], Loss: 0.010953482240438461\n",
"Epoch [131/200], Loss: 0.01096632331609726\n",
"Epoch [132/200], Loss: 0.01100214198231697\n",
"Epoch [133/200], Loss: 0.010964350774884224\n",
"Epoch [134/200], Loss: 0.010790727101266384\n",
"Epoch [135/200], Loss: 0.010693101212382317\n",
"Epoch [136/200], Loss: 0.010713927447795868\n",
"Epoch [137/200], Loss: 0.010633900761604309\n",
"Epoch [138/200], Loss: 0.010555696673691273\n",
"Epoch [139/200], Loss: 0.010487300343811512\n",
"Epoch [140/200], Loss: 0.010380434803664684\n",
"Epoch [141/200], Loss: 0.010355514474213123\n",
"Epoch [142/200], Loss: 0.010410982184112072\n",
"Epoch [143/200], Loss: 0.010351761244237423\n",
"Epoch [144/200], Loss: 0.010418400168418884\n",
"Epoch [145/200], Loss: 0.010332201607525349\n",
"Epoch [146/200], Loss: 0.010343995876610279\n",
"Epoch [147/200], Loss: 0.010427448898553848\n",
"Epoch [148/200], Loss: 0.010323477908968925\n",
"Epoch [149/200], Loss: 0.010058051906526089\n",
"Epoch [150/200], Loss: 0.010284533724188805\n",
"Epoch [151/200], Loss: 0.010071534663438797\n",
"Epoch [152/200], Loss: 0.009985437616705894\n",
"Epoch [153/200], Loss: 0.010022484697401524\n",
"Epoch [154/200], Loss: 0.009919346310198307\n",
"Epoch [155/200], Loss: 0.00982485618442297\n",
"Epoch [156/200], Loss: 0.009892821311950684\n",
"Epoch [157/200], Loss: 0.009983483701944351\n",
"Epoch [158/200], Loss: 0.009957253001630306\n",
"Epoch [159/200], Loss: 0.009947950020432472\n",
"Epoch [160/200], Loss: 0.0098407082259655\n",
"Epoch [161/200], Loss: 0.009712287224829197\n",
"Epoch [162/200], Loss: 0.009856811724603176\n",
"Epoch [163/200], Loss: 0.009634727612137794\n",
"Epoch [164/200], Loss: 0.009648673236370087\n",
"Epoch [165/200], Loss: 0.009548823349177837\n",
"Epoch [166/200], Loss: 0.009575641714036465\n",
"Epoch [167/200], Loss: 0.009495118632912636\n",
"Epoch [168/200], Loss: 0.009415577165782452\n",
"Epoch [169/200], Loss: 0.009466130286455154\n",
"Epoch [170/200], Loss: 0.009326552972197533\n",
"Epoch [171/200], Loss: 0.009429184719920158\n",
"Epoch [172/200], Loss: 0.009482380002737045\n",
"Epoch [173/200], Loss: 0.009315679781138897\n",
"Epoch [174/200], Loss: 0.00941756833344698\n",
"Epoch [175/200], Loss: 0.009411154314875603\n",
"Epoch [176/200], Loss: 0.009331178851425648\n",
"Epoch [177/200], Loss: 0.009307712316513062\n",
"Epoch [178/200], Loss: 0.009414409287273884\n",
"Epoch [179/200], Loss: 0.009246723726391792\n",
"Epoch [180/200], Loss: 0.009288187138736248\n",
"Epoch [181/200], Loss: 0.009330397471785545\n",
"Epoch [182/200], Loss: 0.009157110005617142\n",
"Epoch [183/200], Loss: 0.009136127308011055\n",
"Epoch [184/200], Loss: 0.009142176248133183\n",
"Epoch [185/200], Loss: 0.009172348305583\n",
"Epoch [186/200], Loss: 0.009076423943042755\n",
"Epoch [187/200], Loss: 0.009237755089998245\n",
"Epoch [188/200], Loss: 0.008973689749836922\n",
"Epoch [189/200], Loss: 0.009201562032103539\n",
"Epoch [190/200], Loss: 0.008966690860688686\n",
"Epoch [191/200], Loss: 0.009020267985761166\n",
"Epoch [192/200], Loss: 0.008946004323661327\n",
"Epoch [193/200], Loss: 0.00903538428246975\n",
"Epoch [194/200], Loss: 0.00894266925752163\n",
"Epoch [195/200], Loss: 0.008976456709206104\n",
"Epoch [196/200], Loss: 0.008941787295043468\n",
"Epoch [197/200], Loss: 0.00887981429696083\n",
"Epoch [198/200], Loss: 0.008928196504712105\n",
"Epoch [199/200], Loss: 0.008890601806342602\n",
"Epoch [200/200], Loss: 0.008821516297757626\n"
]
}
],
"source": [
"# Initialize model\n",
"# Widths are taken from the data (n_steps - 1 values per series) instead of being\n",
"# hard-coded, so changing n_steps in the data cell does not break the model.\n",
"input_size = X_train_tensor.shape[1]\n",
"output_size = y_train_tensor.shape[1]\n",
"n_layers = 2\n",
"n_heads = 4  # input_size must be divisible by n_heads\n",
"hidden_size = 32  # dim_feedforward of every Transformer layer\n",
"dropout = 0.1\n",
"model = TransformerModel(input_size, output_size, n_layers, n_heads, hidden_size, dropout).to(device)\n",
"\n",
"# Move the data to the model's device (a no-op if it is already there, so re-running is safe)\n",
"X_train_tensor = X_train_tensor.to(device)\n",
"y_train_tensor = y_train_tensor.to(device)\n",
"X_test_tensor = X_test_tensor.to(device)\n",
"y_test_tensor = y_test_tensor.to(device)\n",
"\n",
"# Define loss function and optimizer\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
"# Train the model: full-batch, i.e. one optimizer step on the whole training set per epoch\n",
"n_epochs = 200\n",
"log_every = 20  # print the loss every `log_every` epochs instead of flooding the output\n",
"model.train()  # training mode: dropout active\n",
"for epoch in range(n_epochs):\n",
"    optimizer.zero_grad()\n",
"    y_pred = model(X_train_tensor)\n",
"    loss = criterion(y_pred, y_train_tensor)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    if epoch == 0 or (epoch + 1) % log_every == 0:\n",
"        print(f\"Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item()}\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test MSE: 0.009519089013338089\n"
]
}
],
"source": [
"# Evaluate the model on the held-out series; eval() switches dropout off\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    y_pred_test = model(X_test_tensor).cpu()\n",
"mse = mean_squared_error(y_test, y_pred_test)\n",
"print(f\"Test MSE: {mse}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot true vs. predicted values for the first test series\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"ax.plot(y_test[0], label=\"True\")\n",
"ax.plot(y_pred_test[0], label=\"Predicted\")\n",
"ax.set(\n",
"    title=\"Next-step prediction for the first test series\",\n",
"    xlabel=\"Target index (time step - 1)\",\n",
"    ylabel=\"Standardized value\",\n",
")\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Comment\n",
"\n",
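"Result: it works, though it may be over-fitting. Two caveats: in the recorded run the final training loss and the test MSE are of similar magnitude, and the target `y` is the input `X` shifted by one step. All but the last target value are therefore already present in the input, so the task is close to learning a shift, and the low error is less impressive than it looks."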
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch210",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}