PlayTransformer/naive-transformer-time-series-test.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Implement a simple Transformer\n",
"\n",
"This notebook implement a simple transformer model, trained on a randomly generated time series data."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_squared_error"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Generate synthetic time series data\n",
"def generate_time_series(n_samples, n_steps):\n",
" freq1, freq2, offset1, offset2 = np.random.rand(4, n_samples, 1)\n",
" time = np.linspace(0, 1, n_steps)\n",
" series = 0.5 * np.sin((time - offset1) * (freq1 * 10 + 10))\n",
" series += 0.2 * np.sin((time - offset2) * (freq2 * 20 + 20))\n",
" series += 0.1 * (np.random.rand(n_samples, n_steps) - 0.5)\n",
" return series.astype(np.float32)\n",
"\n",
"# Create a transformer-based model\n",
"class TransformerModel(nn.Module):\n",
" def __init__(self, input_size, output_size, n_layers, n_heads, hidden_size, dropout):\n",
" super(TransformerModel, self).__init__()\n",
" self.transformer = nn.Transformer(\n",
" d_model=input_size,\n",
" nhead=n_heads,\n",
" num_encoder_layers=n_layers,\n",
" num_decoder_layers=n_layers,\n",
" dim_feedforward=hidden_size,\n",
" dropout=dropout,\n",
" )\n",
" self.fc = nn.Linear(input_size, output_size)\n",
"\n",
" def forward(self, x):\n",
" x = self.transformer(x, x)\n",
" x = self.fc(x)\n",
" return x"
]
},
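{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note (added for clarity): the model above feeds the same tensor as both `src` and `tgt` and applies no attention mask, so the decoder can attend to the entire target. A standard seq-to-seq setup would pass a causal mask; a minimal sketch using PyTorch's built-in helper:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only (not used by the model above): a causal mask keeps\n",
"# position i from attending to positions > i. Entries above the\n",
"# diagonal are -inf, the rest are 0.\n",
"causal_mask = nn.Transformer.generate_square_subsequent_mask(5)\n",
"print(causal_mask)"
]
},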
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
}
],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Prepare data\n",
"n_samples = 1000\n",
"n_steps = 513\n",
"series = generate_time_series(n_samples, n_steps)\n",
"scaler = StandardScaler()\n",
"series_scaled = scaler.fit_transform(series.T).T\n",
"X = series_scaled[:, :n_steps-1]\n",
"y = series_scaled[:, 1:]\n",
"\n",
"# Split data into train and test sets\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"\n",
"# Convert data to PyTorch tensors\n",
"X_train_tensor = torch.tensor(X_train, dtype=torch.float32)\n",
"y_train_tensor = torch.tensor(y_train, dtype=torch.float32)\n",
"X_test_tensor = torch.tensor(X_test, dtype=torch.float32)\n",
"y_test_tensor = torch.tensor(y_test, dtype=torch.float32)"
]
},
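{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check (added; not part of the original run). With the default `batch_first=False`, `nn.Transformer` treats a 2-D input as a single *unbatched* sequence of shape `(seq_len, d_model)`, so here the 800 training series act as 800 tokens of width 512, and attention mixes information across samples rather than across time steps."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sanity check: these (n_samples, n_steps-1) tensors are what\n",
"# nn.Transformer will read as one unbatched (seq_len, d_model) sequence.\n",
"print(X_train_tensor.shape, y_train_tensor.shape)\n",
"print(X_test_tensor.shape, y_test_tensor.shape)"
]
},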
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/torch210/lib/python3.11/site-packages/torch/nn/modules/transformer.py:282: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
" warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n",
"/root/miniconda3/envs/torch210/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/200], Loss: 1.2826976776123047\n",
"Epoch [2/200], Loss: 1.0116225481033325\n",
"Epoch [3/200], Loss: 0.635532021522522\n",
"Epoch [4/200], Loss: 0.4489162862300873\n",
"Epoch [5/200], Loss: 0.35763952136039734\n",
"Epoch [6/200], Loss: 0.2987120449542999\n",
"Epoch [7/200], Loss: 0.30284008383750916\n",
"Epoch [8/200], Loss: 0.2880264222621918\n",
"Epoch [9/200], Loss: 0.29169321060180664\n",
"Epoch [10/200], Loss: 0.2669679522514343\n",
"Epoch [11/200], Loss: 0.2674340307712555\n",
"Epoch [12/200], Loss: 0.25883418321609497\n",
"Epoch [13/200], Loss: 0.23866242170333862\n",
"Epoch [14/200], Loss: 0.21799877285957336\n",
"Epoch [15/200], Loss: 0.20396709442138672\n",
"Epoch [16/200], Loss: 0.19301307201385498\n",
"Epoch [17/200], Loss: 0.17444251477718353\n",
"Epoch [18/200], Loss: 0.16921290755271912\n",
"Epoch [19/200], Loss: 0.1575583964586258\n",
"Epoch [20/200], Loss: 0.14533430337905884\n",
"Epoch [21/200], Loss: 0.1386234015226364\n",
"Epoch [22/200], Loss: 0.12826502323150635\n",
"Epoch [23/200], Loss: 0.11779733747243881\n",
"Epoch [24/200], Loss: 0.11164677143096924\n",
"Epoch [25/200], Loss: 0.10296566784381866\n",
"Epoch [26/200], Loss: 0.09518710523843765\n",
"Epoch [27/200], Loss: 0.08982985466718674\n",
"Epoch [28/200], Loss: 0.08321435749530792\n",
"Epoch [29/200], Loss: 0.07733117789030075\n",
"Epoch [30/200], Loss: 0.07184430211782455\n",
"Epoch [31/200], Loss: 0.0663248598575592\n",
"Epoch [32/200], Loss: 0.06211206316947937\n",
"Epoch [33/200], Loss: 0.0579068660736084\n",
"Epoch [34/200], Loss: 0.054606616497039795\n",
"Epoch [35/200], Loss: 0.051191166043281555\n",
"Epoch [36/200], Loss: 0.047771405428647995\n",
"Epoch [37/200], Loss: 0.045270729809999466\n",
"Epoch [38/200], Loss: 0.04388266056776047\n",
"Epoch [39/200], Loss: 0.04167153686285019\n",
"Epoch [40/200], Loss: 0.03903637081384659\n",
"Epoch [41/200], Loss: 0.03716022148728371\n",
"Epoch [42/200], Loss: 0.035668425261974335\n",
"Epoch [43/200], Loss: 0.03426579385995865\n",
"Epoch [44/200], Loss: 0.0328022837638855\n",
"Epoch [45/200], Loss: 0.03179449960589409\n",
"Epoch [46/200], Loss: 0.030734049156308174\n",
"Epoch [47/200], Loss: 0.029527556151151657\n",
"Epoch [48/200], Loss: 0.028938110917806625\n",
"Epoch [49/200], Loss: 0.027805780991911888\n",
"Epoch [50/200], Loss: 0.02714334987103939\n",
"Epoch [51/200], Loss: 0.026013001799583435\n",
"Epoch [52/200], Loss: 0.026148412376642227\n",
"Epoch [53/200], Loss: 0.02501417137682438\n",
"Epoch [54/200], Loss: 0.024782532826066017\n",
"Epoch [55/200], Loss: 0.023976940661668777\n",
"Epoch [56/200], Loss: 0.02344522438943386\n",
"Epoch [57/200], Loss: 0.02298690192401409\n",
"Epoch [58/200], Loss: 0.022800864651799202\n",
"Epoch [59/200], Loss: 0.02203877829015255\n",
"Epoch [60/200], Loss: 0.02186082862317562\n",
"Epoch [61/200], Loss: 0.02145547792315483\n",
"Epoch [62/200], Loss: 0.02111400105059147\n",
"Epoch [63/200], Loss: 0.021111004054546356\n",
"Epoch [64/200], Loss: 0.021235451102256775\n",
"Epoch [65/200], Loss: 0.020454218611121178\n",
"Epoch [66/200], Loss: 0.020284976810216904\n",
"Epoch [67/200], Loss: 0.019960373640060425\n",
"Epoch [68/200], Loss: 0.019969569519162178\n",
"Epoch [69/200], Loss: 0.019146228209137917\n",
"Epoch [70/200], Loss: 0.01951329968869686\n",
"Epoch [71/200], Loss: 0.01905273087322712\n",
"Epoch [72/200], Loss: 0.018939871340990067\n",
"Epoch [73/200], Loss: 0.018563639372587204\n",
"Epoch [74/200], Loss: 0.018615636974573135\n",
"Epoch [75/200], Loss: 0.018304886296391487\n",
"Epoch [76/200], Loss: 0.01797996088862419\n",
"Epoch [77/200], Loss: 0.017887452617287636\n",
"Epoch [78/200], Loss: 0.017940785735845566\n",
"Epoch [79/200], Loss: 0.017486266791820526\n",
"Epoch [80/200], Loss: 0.01746840588748455\n",
"Epoch [81/200], Loss: 0.017102232202887535\n",
"Epoch [82/200], Loss: 0.017105089500546455\n",
"Epoch [83/200], Loss: 0.016733380034565926\n",
"Epoch [84/200], Loss: 0.016671737655997276\n",
"Epoch [85/200], Loss: 0.016277296468615532\n",
"Epoch [86/200], Loss: 0.01615353487432003\n",
"Epoch [87/200], Loss: 0.016070151701569557\n",
"Epoch [88/200], Loss: 0.015822922810912132\n",
"Epoch [89/200], Loss: 0.015388056635856628\n",
"Epoch [90/200], Loss: 0.015657253563404083\n",
"Epoch [91/200], Loss: 0.015209450386464596\n",
"Epoch [92/200], Loss: 0.01531938835978508\n",
"Epoch [93/200], Loss: 0.015104217454791069\n",
"Epoch [94/200], Loss: 0.015498699620366096\n",
"Epoch [95/200], Loss: 0.01536652259528637\n",
"Epoch [96/200], Loss: 0.015862323343753815\n",
"Epoch [97/200], Loss: 0.015583674423396587\n",
"Epoch [98/200], Loss: 0.014587637968361378\n",
"Epoch [99/200], Loss: 0.014702104032039642\n",
"Epoch [100/200], Loss: 0.014520714059472084\n",
"Epoch [101/200], Loss: 0.014230665750801563\n",
"Epoch [102/200], Loss: 0.014032993465662003\n",
"Epoch [103/200], Loss: 0.014286665245890617\n",
"Epoch [104/200], Loss: 0.01409598346799612\n",
"Epoch [105/200], Loss: 0.013588094152510166\n",
"Epoch [106/200], Loss: 0.013453097082674503\n",
"Epoch [107/200], Loss: 0.013409607112407684\n",
"Epoch [108/200], Loss: 0.013149403035640717\n",
"Epoch [109/200], Loss: 0.013141301460564137\n",
"Epoch [110/200], Loss: 0.01283340621739626\n",
"Epoch [111/200], Loss: 0.012824492529034615\n",
"Epoch [112/200], Loss: 0.012342739850282669\n",
"Epoch [113/200], Loss: 0.012487290427088737\n",
"Epoch [114/200], Loss: 0.012282567098736763\n",
"Epoch [115/200], Loss: 0.012123352847993374\n",
"Epoch [116/200], Loss: 0.012309921905398369\n",
"Epoch [117/200], Loss: 0.011849000118672848\n",
"Epoch [118/200], Loss: 0.011969990096986294\n",
"Epoch [119/200], Loss: 0.011744375340640545\n",
"Epoch [120/200], Loss: 0.011610714718699455\n",
"Epoch [121/200], Loss: 0.011537344194948673\n",
"Epoch [122/200], Loss: 0.011500836350023746\n",
"Epoch [123/200], Loss: 0.011424414813518524\n",
"Epoch [124/200], Loss: 0.011428999714553356\n",
"Epoch [125/200], Loss: 0.011216061189770699\n",
"Epoch [126/200], Loss: 0.011091525666415691\n",
"Epoch [127/200], Loss: 0.011024504899978638\n",
"Epoch [128/200], Loss: 0.011121418327093124\n",
"Epoch [129/200], Loss: 0.010962490923702717\n",
"Epoch [130/200], Loss: 0.010953482240438461\n",
"Epoch [131/200], Loss: 0.01096632331609726\n",
"Epoch [132/200], Loss: 0.01100214198231697\n",
"Epoch [133/200], Loss: 0.010964350774884224\n",
"Epoch [134/200], Loss: 0.010790727101266384\n",
"Epoch [135/200], Loss: 0.010693101212382317\n",
"Epoch [136/200], Loss: 0.010713927447795868\n",
"Epoch [137/200], Loss: 0.010633900761604309\n",
"Epoch [138/200], Loss: 0.010555696673691273\n",
"Epoch [139/200], Loss: 0.010487300343811512\n",
"Epoch [140/200], Loss: 0.010380434803664684\n",
"Epoch [141/200], Loss: 0.010355514474213123\n",
"Epoch [142/200], Loss: 0.010410982184112072\n",
"Epoch [143/200], Loss: 0.010351761244237423\n",
"Epoch [144/200], Loss: 0.010418400168418884\n",
"Epoch [145/200], Loss: 0.010332201607525349\n",
"Epoch [146/200], Loss: 0.010343995876610279\n",
"Epoch [147/200], Loss: 0.010427448898553848\n",
"Epoch [148/200], Loss: 0.010323477908968925\n",
"Epoch [149/200], Loss: 0.010058051906526089\n",
"Epoch [150/200], Loss: 0.010284533724188805\n",
"Epoch [151/200], Loss: 0.010071534663438797\n",
"Epoch [152/200], Loss: 0.009985437616705894\n",
"Epoch [153/200], Loss: 0.010022484697401524\n",
"Epoch [154/200], Loss: 0.009919346310198307\n",
"Epoch [155/200], Loss: 0.00982485618442297\n",
"Epoch [156/200], Loss: 0.009892821311950684\n",
"Epoch [157/200], Loss: 0.009983483701944351\n",
"Epoch [158/200], Loss: 0.009957253001630306\n",
"Epoch [159/200], Loss: 0.009947950020432472\n",
"Epoch [160/200], Loss: 0.0098407082259655\n",
"Epoch [161/200], Loss: 0.009712287224829197\n",
"Epoch [162/200], Loss: 0.009856811724603176\n",
"Epoch [163/200], Loss: 0.009634727612137794\n",
"Epoch [164/200], Loss: 0.009648673236370087\n",
"Epoch [165/200], Loss: 0.009548823349177837\n",
"Epoch [166/200], Loss: 0.009575641714036465\n",
"Epoch [167/200], Loss: 0.009495118632912636\n",
"Epoch [168/200], Loss: 0.009415577165782452\n",
"Epoch [169/200], Loss: 0.009466130286455154\n",
"Epoch [170/200], Loss: 0.009326552972197533\n",
"Epoch [171/200], Loss: 0.009429184719920158\n",
"Epoch [172/200], Loss: 0.009482380002737045\n",
"Epoch [173/200], Loss: 0.009315679781138897\n",
"Epoch [174/200], Loss: 0.00941756833344698\n",
"Epoch [175/200], Loss: 0.009411154314875603\n",
"Epoch [176/200], Loss: 0.009331178851425648\n",
"Epoch [177/200], Loss: 0.009307712316513062\n",
"Epoch [178/200], Loss: 0.009414409287273884\n",
"Epoch [179/200], Loss: 0.009246723726391792\n",
"Epoch [180/200], Loss: 0.009288187138736248\n",
"Epoch [181/200], Loss: 0.009330397471785545\n",
"Epoch [182/200], Loss: 0.009157110005617142\n",
"Epoch [183/200], Loss: 0.009136127308011055\n",
"Epoch [184/200], Loss: 0.009142176248133183\n",
"Epoch [185/200], Loss: 0.009172348305583\n",
"Epoch [186/200], Loss: 0.009076423943042755\n",
"Epoch [187/200], Loss: 0.009237755089998245\n",
"Epoch [188/200], Loss: 0.008973689749836922\n",
"Epoch [189/200], Loss: 0.009201562032103539\n",
"Epoch [190/200], Loss: 0.008966690860688686\n",
"Epoch [191/200], Loss: 0.009020267985761166\n",
"Epoch [192/200], Loss: 0.008946004323661327\n",
"Epoch [193/200], Loss: 0.00903538428246975\n",
"Epoch [194/200], Loss: 0.00894266925752163\n",
"Epoch [195/200], Loss: 0.008976456709206104\n",
"Epoch [196/200], Loss: 0.008941787295043468\n",
"Epoch [197/200], Loss: 0.00887981429696083\n",
"Epoch [198/200], Loss: 0.008928196504712105\n",
"Epoch [199/200], Loss: 0.008890601806342602\n",
"Epoch [200/200], Loss: 0.008821516297757626\n"
]
}
],
"source": [
"# Initialize model\n",
"input_size = 512\n",
"output_size = 512\n",
"n_layers = 2\n",
"n_heads = 4\n",
"hidden_size = 32\n",
"dropout = 0.1\n",
"model = TransformerModel(input_size, output_size, n_layers, n_heads, hidden_size, dropout).to(device)\n",
"\n",
"\n",
"# Convert data to the appropriate device\n",
"X_train_tensor = X_train_tensor.to(device)\n",
"y_train_tensor = y_train_tensor.to(device)\n",
"X_test_tensor = X_test_tensor.to(device)\n",
"y_test_tensor = y_test_tensor.to(device)\n",
"\n",
"\n",
"# Define loss function and optimizer\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
"# Train the model\n",
"n_epochs = 200\n",
"for epoch in range(n_epochs):\n",
" optimizer.zero_grad()\n",
" y_pred = model(X_train_tensor)\n",
" loss = criterion(y_pred, y_train_tensor)\n",
" loss.backward()\n",
" optimizer.step()\n",
" print(f\"Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item()}\")"
]
},
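{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above does full-batch gradient descent: every step runs the entire training set through the model. That is fine for 800 short series but scales poorly. A minimal mini-batch variant using `torch.utils.data` might look like the sketch below (illustrative only, not executed here; the batch size and epoch count are arbitrary, and with the 2-D tensors used in this notebook each mini-batch is again consumed as one unbatched sequence):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical mini-batch training loop (sketch, assumes the same\n",
"# model, criterion and optimizer as above).\n",
"from torch.utils.data import TensorDataset, DataLoader\n",
"\n",
"loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor),\n",
"                    batch_size=64, shuffle=True)\n",
"for epoch in range(5):\n",
"    for xb, yb in loader:\n",
"        optimizer.zero_grad()\n",
"        loss = criterion(model(xb), yb)\n",
"        loss.backward()\n",
"        optimizer.step()"
]
},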
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test MSE: 0.009519089013338089\n"
]
}
],
"source": [
"# Evaluate the model\n",
"with torch.no_grad():\n",
" model.eval()\n",
" y_pred_test = model(X_test_tensor)\n",
" y_pred_test = y_pred_test.cpu()\n",
" mse = mean_squared_error(y_test, y_pred_test)\n",
" print(f\"Test MSE: {mse}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0YAAAH5CAYAAAClJy6RAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAADhDUlEQVR4nOzdd5hdZbX48e8+fU6Z3pNJMimkEUIIJaFDKIK9IqDCFVTUa0Ou/gAFruViQUVEVERAxAKKokiX3kNCAqSQnslkej9zetu/P95T5sxMGszMPmV9nmeenN0mK+Wc2Wuv912vpuu6jhBCCCGEEEIUMZPRAQghhBBCCCGE0SQxEkIIIYQQQhQ9SYyEEEIIIYQQRU8SIyGEEEIIIUTRk8RICCGEEEIIUfQkMRJCCCGEEEIUPUmMhBBCCCGEEEXPYnQAEy2RSNDe3o7H40HTNKPDEUIIIYQQQhhE13WGh4dpbGzEZNp/TajgEqP29naampqMDkMIIYQQQgiRI1pbW5k+ffp+zym4xMjj8QDqD19aWmpwNEIIIYQQQgijeL1empqa0jnC/hRcYpQaPldaWiqJkRBCCCGEEOKgpthI8wUhhBBCCCFE0ZPESAghhBBCCFH0JDESQgghhBBCFL2Cm2MkhBBCCCHEoYjH40SjUaPDEG+T1WrFbDa/4+8jiZEQQgghhChKuq7T2dnJ4OCg0aGId6i8vJz6+vp3tI6pJEZCCCGEEKIopZKi2tpanE7nO7qpFsbQdZ1AIEB3dzcADQ0Nb/t7SWIkhBBCCCGKTjweTydFVVVVRocj3oGSkhIAuru7qa2tfdvD6qT5ghBCCCGEKDqpOUVOp9PgSMRESP07vpO5YpIYCSGEEEKIoiXD5wrDRPw7SmIkhBBCCCGEKHqSGAkhhBBCCCGKniRGQgghhBBCiKIniZEQQgghhBB5QNO0/X5dfPHFRoeY16RdtxBCCCGEEHmgo6Mj/fqee+7hmmuuYcuWLel9qbbVKdFoFKvVOmXx5TupGAkhhBBCCEFysdBIbMq/dF0/qPjq6+vTX2VlZWialt4OhUKUl5dz7733cuqpp+JwOLj77ru57rrrOPLII7O+z4033sisWbOy9t1xxx0sXLgQh8PBggULuOWWWybobzV/SMVICCGEEEIIIBiNs+iaR6f89930nbNx2ibmtvyb3/wmP/nJT7jjjjuw2+3ceuutB7zmt7/9Lddeey0333wzy5YtY926dXzmM5/B5XJx0UUXTUhc+UASIyGEEEIIIQrEV7/6VT70oQ8d0jXf/e53+clPfpK+rrm5mU2bNvGb3/xGEiMhhBBCCCGmVDwK//oyzD4Fln7ckBBKrGY2fedsQ37fiXL00Ucf0vk9PT20trZyySWX8JnPfCa9PxaLUVZWNmFx5QNJjIQQQgghhPFaV8Prf4I9LxmWGGmaNmFD2ozicrmytk0m05g5TNFoNP06kUgAajjdcccdl3We2TxxCVs+yO9/eSGEEEIIkfcefKODZx94lh8CBPuNDqeg1NTU0NnZia7raJoGwPr169PH6+rqmDZtGjt37uTCCy80KMrcIImREEKI/JJIwF/OB12H8/8CJmmwKkSui8QSBCIxyp22cY8/vKEDm38QbEDIq97n8t6eEKeeeio9PT386Ec/4iMf+QiPPPIIDz/8MKWlpelzrrvuOr785S9TWlrKOeecQzgcZs2aNQwMDHD55ZcbGP3Ukv9xQggh8kvXBtj6CGx7FIY7Dny+EMJwF9+xmhXXP0HPcHjc472+MKVaILmlQ9g7dcEVuIULF3LLLbfwy1/+kqVLl7J69WquuOKKrHMuvfRSbrvtNu68806WLFnCKaecwp133klzc7NBURtDKkZCCCHyy+7nMq+HO6Bs2sFdl4jDE/8LM0+Ew86anNiEEGNEYgle2dVPPKGztWuYGo99zDl9vggeApkdoUEoKc86ZygQpbTEkh4OVuwuvvhiLr744vT2rFmz9rke0mWXXcZll12Wte+qq67K2r7gggu44IILJjzOfCIVIyGEEPll14jEyNt+8Net/yO88HP400cnPiYhxD7t7vMTT6gb9v1VjDxaMLMjOJh1fGP7EMu++xjX/WvjZIUphCRGQggh8kgiDi0vZrZHDqXztsO/L4f+XeNf2/HG5MYmhBjX9m5f+nX3cGjM8Vg8wUAgOrZiNMLrrUMkdHijbWiywhRCEiMhhBB5pON1CI+4MfK2ZV4//zNY8zv677uczR3jzE8YdaMlhJgaIxOj8SpG/f4IAB5tZGKk3ufd3hCDgQh9PnWdPxxjU7uXL/95HS19/kmMWhQjSYyEEELkj9bVWZvxoRFD6VpfAaBs71P818/vH3vtyKE58ejY40KISbFtX4lRcADuvYjg5scAKB1ZMQoOEojE+PZPf86VP7+NvmTy5A/H+fSdr/Kv19u55PdrpiR+UTwkMRJCCJE/ercCsFevBkBvfx1uWQmPXAWdGwAwazofMz9DOBbPvjY4kHkdHp6ScIUQoypGvhGJ0Qs/h033M/PhTwJkzzEKDdLd0cov9R/ww9B32N2jqsC+cIxOb2jM9xViIkhiJIQQIn8kE6MX4ocDYOnfCt2b4OVfgp5JhM63PMnA4ED2tSPnI0krYCGmRDyhs6MnlcDouAa3QjymNgN96fPmaG3Zc4yCg0TaN2DREpRqQQbadwBqKJ0Qk0USIyGEEHkj3rMNgBcSi8c9/nj8KPbq1TRo/Zif/XHmQCyS3cFOKkZCTIm9AwEisQQA55uf5Fbfl+DFm9RBfyYxOsu0dtQco0ES3VvSm2XBPRyu7cSWCOKymZihdQH6PttTC/F2SGIkhBAiP4S8mP2dALy0j8To5cQiro1eBED1m7+FHlVhwrsXGHEDFZKKkRBToaVPJTs1HjsLtT0AJHa/oA4OZDpInmVeg4eRQ+mGsAxsT29eYH6Sf9u/xe+sN3C56S88a/8aHzQ9T68vMvl/CFE0JjUxevbZZ3nve99LY2MjmqZx//337/f8p59+Gk3Txny99dZbkxmmEEKIfNCnbpJ69DJ6KM8+Zi8F4JXEAp5ILOfp+FI0PQ5v3gt716Kv/1P2+SMrRlsegXs+AYH+SQxeiCLh7YA/XwA7ngRgIKASl9nVLmpNqtNconsz6DoM7E5ftsy0Hac2sjHDICVDO9Kb7zK/CsBK8yYu4X5ADZnd0y+d6SbTddddx5FHHpnevvjii/nABz4w5XHs3r0bTdNYv379pP4+k5oY+f1+li5dys0333xI123ZsoWOjo7017x58yYpQiGEEHmjVw2j26E3jj32ifv447RvsUGfDcA/48er/a/fA3eeizZyWB1kzzF6/mew+QHY8tBkRC1EcdnwN9jyILx0CwDeoOoAWemyUW9WDyQsw23QvxOiAdBMDJkqxnwbPTRIqW8fa5IlRXVLuiJVbC6++OJ0AcFqtTJ79myuuOIK/P7JTRR//vOfc+eddx7UuVOVzEwky2R+83POOYdzzjnnkK+rra2lvLx84gMSQgiRv/pUYrQz0YDHPuLHV1kTN
B3LUzYN6MZtt/BEeBkJzJiG9oz/vUYmRoMt6tfhDoiGIB4BR+nk/BmEKHTJBxipNcYGAyoxKiuxUqMNZUa0bn1E/Vo6nRZ/GUckspul9LftoIrB/f5WM7RuVhdpYgTwrne9izvuuINoNMpzzz3HpZdeit/v51e/+lXWedFoFKvVOiG/Z1lZ2YR8n1yVk3OMli1bRkNDA6tWreKpp57a77nhcBiv15v1JYQQogAlO9Lt0BtZ0ODhf6KfxWethI/9HoCh5JPp2TUuvLjZ7jpy398rNccoFs50qxvuhN+eDjcdCZHivdkS4h1JDnlNJ0bJ92WZ00qFPpg+bcfzf1MvKmexO1E75tukkqKQnrmh9+pOzo9czTXJeYSNWi97ewfHXFss7HY79fX1NDU1ccEFF3DhhRdy//33p4e/3X777cyePRu73Y6u6wwNDfHZz36W2tpaSktLOf3003n99dezvucPfvAD6urq8Hg8XHLJJYRCoazjo4fSJRIJfvjDHzJ37lzsdjszZszg+9//PgDNzc2
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot predictions\n",
"plt.figure(figsize=(10, 6))\n",
"plt.plot(y_test[0], label=\"True\")\n",
"plt.plot(y_pred_test[0], label=\"Predicted\")\n",
"plt.legend()\n",
"plt.show()"
]
},
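{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check the overfitting claim in the comment below, one can compare the training MSE against the test MSE printed above (added sketch; these numbers were not part of the original run):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added: train-vs-test MSE gap as a rough measure of overfitting\n",
"with torch.no_grad():\n",
"    y_pred_train = model(X_train_tensor).cpu()\n",
"train_mse = mean_squared_error(y_train, y_pred_train)\n",
"print(f\"Train MSE: {train_mse}, Test MSE: {mse}\")"
]
},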
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Comment\n",
"\n",
"Result over-fit, but it works."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch210",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}