From cbb77d3f335ccdb2b786aa79b0792e2f03185e0e Mon Sep 17 00:00:00 2001 From: AJYChen Date: Sun, 3 Oct 2021 17:59:38 +0800 Subject: [PATCH] Fix bugs --- .../Sunshine_LSTM-checkpoint.ipynb | 1960 +++++++++++++++++ Sunshine_LSTM.ipynb | 17 +- Sunshine_LSTM.py | 171 ++ 3 files changed, 2133 insertions(+), 15 deletions(-) create mode 100644 .ipynb_checkpoints/Sunshine_LSTM-checkpoint.ipynb create mode 100644 Sunshine_LSTM.py diff --git a/.ipynb_checkpoints/Sunshine_LSTM-checkpoint.ipynb b/.ipynb_checkpoints/Sunshine_LSTM-checkpoint.ipynb new file mode 100644 index 0000000..8c2d6ff --- /dev/null +++ b/.ipynb_checkpoints/Sunshine_LSTM-checkpoint.ipynb @@ -0,0 +1,1960 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import keras\n", + "import tensorflow as tf\n", + "from keras.preprocessing.sequence import TimeseriesGenerator\n", + "from keras.models import Sequential\n", + "from keras.layers import LSTM, Dense" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 7395 entries, 0 to 7394\n", + "Data columns (total 14 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 YEAR 7395 non-null int64 \n", + " 1 MO 7395 non-null int64 \n", + " 2 DY 7395 non-null int64 \n", + " 3 ALLSKY_SFC_SW_DWN 7395 non-null float64\n", + " 4 CLRSKY_SFC_SW_DWN 7395 non-null float64\n", + " 5 WS2M 7395 non-null float64\n", + " 6 ALLSKY_KT 7395 non-null float64\n", + " 7 ALLSKY_NKT 7395 non-null float64\n", + " 8 ALLSKY_SFC_LW_DWN 7395 non-null float64\n", + " 9 ALLSKY_SFC_PAR_TOT 7395 non-null float64\n", + " 10 CLRSKY_SFC_PAR_TOT 7395 non-null float64\n", + " 11 ALLSKY_SFC_UVA 7395 non-null float64\n", + " 12 ALLSKY_SFC_UVB 7395 non-null float64\n", + " 13 ALLSKY_SFC_UV_INDEX 7395 non-null float64\n", + "dtypes: float64(11), int64(3)\n", + "memory usage: 809.0 KB\n", + "None\n" + ] + } + ], + "source": [ + "filename = \"data.csv\"\n", + "df = pd.read_csv(filename)\n", + "print(df.info())" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
YEARMODYALLSKY_SFC_SW_DWNCLRSKY_SFC_SW_DWNWS2MALLSKY_KTALLSKY_NKTALLSKY_SFC_LW_DWNALLSKY_SFC_PAR_TOTCLRSKY_SFC_PAR_TOTALLSKY_SFC_UVAALLSKY_SFC_UVBALLSKY_SFC_UV_INDEX
02001113.965.109.250.560.75385.3774.9895.799.910.251.34
12001124.085.025.840.580.77387.0677.6195.2710.020.251.32
22001134.265.025.410.600.80374.7481.5996.2710.490.271.41
32001144.714.996.810.660.88373.5589.2394.9611.220.281.45
42001154.354.897.300.610.81386.3483.4593.8710.610.271.40
.............................................
739020213276.537.065.260.650.77411.87128.84137.5816.950.482.43
739120213286.586.893.330.660.78406.53128.20132.8416.440.472.43
739220213295.866.201.290.580.69400.70111.09116.5213.550.382.00
739320213306.036.092.190.590.71390.62112.07112.8613.340.38-999.00
739420213316.416.501.200.630.75387.44120.46121.5514.910.412.16
\n", + "

7395 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " YEAR MO DY ALLSKY_SFC_SW_DWN CLRSKY_SFC_SW_DWN WS2M ALLSKY_KT \\\n", + "0 2001 1 1 3.96 5.10 9.25 0.56 \n", + "1 2001 1 2 4.08 5.02 5.84 0.58 \n", + "2 2001 1 3 4.26 5.02 5.41 0.60 \n", + "3 2001 1 4 4.71 4.99 6.81 0.66 \n", + "4 2001 1 5 4.35 4.89 7.30 0.61 \n", + "... ... .. .. ... ... ... ... \n", + "7390 2021 3 27 6.53 7.06 5.26 0.65 \n", + "7391 2021 3 28 6.58 6.89 3.33 0.66 \n", + "7392 2021 3 29 5.86 6.20 1.29 0.58 \n", + "7393 2021 3 30 6.03 6.09 2.19 0.59 \n", + "7394 2021 3 31 6.41 6.50 1.20 0.63 \n", + "\n", + " ALLSKY_NKT ALLSKY_SFC_LW_DWN ALLSKY_SFC_PAR_TOT CLRSKY_SFC_PAR_TOT \\\n", + "0 0.75 385.37 74.98 95.79 \n", + "1 0.77 387.06 77.61 95.27 \n", + "2 0.80 374.74 81.59 96.27 \n", + "3 0.88 373.55 89.23 94.96 \n", + "4 0.81 386.34 83.45 93.87 \n", + "... ... ... ... ... \n", + "7390 0.77 411.87 128.84 137.58 \n", + "7391 0.78 406.53 128.20 132.84 \n", + "7392 0.69 400.70 111.09 116.52 \n", + "7393 0.71 390.62 112.07 112.86 \n", + "7394 0.75 387.44 120.46 121.55 \n", + "\n", + " ALLSKY_SFC_UVA ALLSKY_SFC_UVB ALLSKY_SFC_UV_INDEX \n", + "0 9.91 0.25 1.34 \n", + "1 10.02 0.25 1.32 \n", + "2 10.49 0.27 1.41 \n", + "3 11.22 0.28 1.45 \n", + "4 10.61 0.27 1.40 \n", + "... ... ... ... \n", + "7390 16.95 0.48 2.43 \n", + "7391 16.44 0.47 2.43 \n", + "7392 13.55 0.38 2.00 \n", + "7393 13.34 0.38 -999.00 \n", + "7394 14.91 0.41 2.16 \n", + "\n", + "[7395 rows x 14 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "df = df.drop(['YEAR', 'MO', 'DY','ALLSKY_SFC_UVA','ALLSKY_SFC_UVB','ALLSKY_SFC_SW_DWN','CLRSKY_SFC_SW_DWN','WS2M','ALLSKY_KT','ALLSKY_NKT','ALLSKY_SFC_LW_DWN','CLRSKY_SFC_PAR_TOT'], axis = 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ALLSKY_SFC_PAR_TOTALLSKY_SFC_UV_INDEX
074.981.34
177.611.32
281.591.41
389.231.45
483.451.40
.........
7390128.842.43
7391128.202.43
7392111.092.00
7393112.07-999.00
7394120.462.16
\n", + "

7395 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " ALLSKY_SFC_PAR_TOT ALLSKY_SFC_UV_INDEX\n", + "0 74.98 1.34\n", + "1 77.61 1.32\n", + "2 81.59 1.41\n", + "3 89.23 1.45\n", + "4 83.45 1.40\n", + "... ... ...\n", + "7390 128.84 2.43\n", + "7391 128.20 2.43\n", + "7392 111.09 2.00\n", + "7393 112.07 -999.00\n", + "7394 120.46 2.16\n", + "\n", + "[7395 rows x 2 columns]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "allskypar = df['ALLSKY_SFC_PAR_TOT'].values\n", + "allskypar = allskypar.reshape((-1,1))\n", + "df.insert(0, 'INDEX', range(1, 1 + len(df)))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7321\n", + "74\n" + ] + } + ], + "source": [ + "split_percent = 0.99\n", + "split = int(split_percent*len(allskypar))\n", + "\n", + "allskypar_train = allskypar[:split]\n", + "allskypar_test = allskypar[split:]\n", + "\n", + "date_train = df['INDEX'][:split]\n", + "date_test = df['INDEX'][split:]\n", + "\n", + "print(len(allskypar_train))\n", + "print(len(allskypar_test))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "look_back = 30\n", + "\n", + "train_generator = TimeseriesGenerator(allskypar_train, allskypar_train, length=look_back, batch_size=20) \n", + "test_generator = TimeseriesGenerator(allskypar_test, allskypar_test, length=look_back, batch_size=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 1086.8380\n", + "Epoch 2/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 577.1627\n", + "Epoch 3/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 548.5074\n", + "Epoch 4/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 549.1390\n", + "Epoch 5/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 526.5082\n", + "Epoch 6/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 534.1251\n", + "Epoch 7/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 527.6324\n", + "Epoch 8/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 534.5848\n", + "Epoch 9/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 523.3795\n", + "Epoch 10/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 521.5023\n", + "Epoch 11/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 519.7992\n", + "Epoch 12/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 1670170.3750\n", + "Epoch 13/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 3424.5903\n", + "Epoch 14/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 763.5213\n", + "Epoch 15/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 608.0161\n", + "Epoch 16/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 591.7673\n", + "Epoch 17/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 557.9049\n", + "Epoch 18/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 545.4752\n", + "Epoch 19/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 544.2651\n", + "Epoch 20/25\n", + "365/365 [==============================] - 4s 10ms/step - loss: 538.6804\n", + "Epoch 21/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 538.7037\n", + "Epoch 22/25\n", + "365/365 [==============================] - 4s 11ms/step - loss: 533.7980\n", + "Epoch 23/25\n", + "365/365 [==============================] - 4s 10ms/step - loss: 567.3826\n", + "Epoch 24/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 525.5330\n", + "Epoch 25/25\n", + "365/365 [==============================] - 3s 9ms/step - loss: 525.9817\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = Sequential()\n", + "model.add(\n", + " LSTM(100,\n", + " activation='relu',\n", + " input_shape=(look_back,1))\n", + ")\n", + "model.add(Dense(2))\n", + "model.compile(optimizer='adam', loss='mse')\n", + "\n", + "num_epochs = 25\n", + "model.fit(train_generator, epochs=num_epochs, verbose=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "mode": "lines", + "name": "Ground Truth", + "type": "scatter", + "x": [ + 7322, + 7323, + 7324, + 7325, + 7326, + 7327, + 7328, + 7329, + 7330, + 7331, + 7332, + 7333, + 7334, + 7335, + 7336, + 7337, + 7338, + 7339, + 7340, + 7341, + 7342, + 7343, + 7344, + 7345, + 7346, + 7347, + 7348, + 7349, + 7350, + 7351, + 7352, + 7353, + 7354, + 7355, + 7356, + 7357, + 7358, + 7359, + 7360, + 7361, + 7362, + 7363, + 7364, + 7365, + 7366, + 7367, + 7368, + 7369, + 7370, + 7371, + 7372, + 7373, + 7374, + 7375, + 7376, + 7377, + 7378, + 7379, + 7380, + 7381, + 7382, + 7383, + 7384, + 7385, + 7386, + 7387, + 7388, + 7389, + 7390, + 7391, + 7392, + 7393, + 7394, + 7395 + ], + "y": [ + 38.02, + 55.34, + 32.91, + 64.97, + 45.7, + 52.55, + 77.27, + 81.23, + 31.1, + 90.81, + 91.77, + 66.9, + 58.49, + 79.45, + 81.38, + 91.9, + 83.01, + 76.12, + 92.56, + 107.47, + 107.45, + 52.53, + 88.55, + 85.2, + 41.23, + 89.79, + 107.57, + 110.53, + 113.33, + 111.09, + 86.47, + 62.86, + 74.45, + 88.41, + 79.86, + 101.84, + 83.05, + 58.56, + 56.99, + 106.49, + 117.2, + 117.65, + 118.07, + 115.05, + 101.78, + 94.29, + 105.94, + 107.59, + 114.23, + 91.12, + 105.63, + 102.09, + 105.66, + 112.83, + 122.8, + 105.67, + 116.16, + 96.92, + 75.97, + 104.91, + 117.21, + 130.15, + 129.71, + 90.13, + 41.3, + 63.95, + 62.32, + 54.77, + 103.56, + 128.84, + 128.2, + 111.09, + 112.07, + 120.46 + ] + }, + { + "mode": "lines", + "name": "Prediction", + "type": "scatter", + "x": [ + 7322, + 7323, + 7324, + 7325, + 7326, + 7327, + 7328, + 7329, + 7330, + 7331, + 7332, + 7333, + 7334, + 7335, + 7336, + 7337, + 7338, + 7339, + 7340, + 7341, + 7342, + 7343, + 7344, + 7345, + 7346, + 7347, + 7348, + 7349, + 7350, + 7351, + 7352, + 7353, + 7354, + 7355, + 7356, + 7357, + 7358, + 7359, + 7360, + 7361, + 7362, + 7363, + 7364, + 7365, + 7366, + 7367, + 7368, + 7369, + 7370, + 7371, + 7372, + 7373, + 7374, + 7375, + 7376, + 7377, + 7378, + 7379, + 7380, + 7381, + 7382, + 7383, + 7384, + 7385, + 7386, + 7387, + 7388, + 7389, + 7390, + 7391, + 7392, + 7393, + 7394, + 7395 + ], + "y": [ + 98.83181762695312, + 99.81380462646484, + 83.80471801757812, + 83.92058563232422, + 72.02377319335938, + 71.81573486328125, + 78.9570083618164, + 79.8797607421875, + 83.42620849609375, + 83.29237365722656, + 76.66921997070312, + 76.23536682128906, + 94.64385223388672, + 94.6901626586914, + 81.48623657226562, + 80.6895980834961, + 70.16002655029297, + 73.18270111083984, + 69.88146209716797, + 67.95851135253906, + 96.52891540527344, + 98.48237609863281, + 107.59458923339844, + 108.63009643554688, + 105.25763702392578, + 105.21949005126953, + 106.46473693847656, + 106.51319885253906, + 106.04330444335938, + 104.49585723876953, + 96.79668426513672, + 96.9975814819336, + 93.816650390625, + 93.49519348144531, + 101.24125671386719, + 99.92263793945312, + 96.39584350585938, + 96.83317565917969, + 104.6543960571289, + 104.57909393310547, + 95.47126770019531, + 95.33518981933594, + 100.1097640991211, + 102.34142303466797, + 100.35750579833984, + 102.21908569335938, + 98.66891479492188, + 100.58856201171875, + 104.74982452392578, + 105.34442901611328, + 111.75166320800781, + 112.16352844238281, + 102.96778106689453, + 104.25030517578125, + 108.79991912841797, + 109.8946762084961, + 95.38335418701172, + 95.30912017822266, + 89.80103302001953, + 90.94747924804688, + 97.10368347167969, + 97.30387878417969, + 105.98876190185547, + 106.23373413085938, + 117.33837890625, + 118.23141479492188, + 113.98306274414062, + 114.24811553955078, + 97.7326889038086, + 98.3048095703125, + 74.45018768310547, + 75.55924224853516, + 85.17854309082031, + 86.841064453125, + 68.82337951660156, + 66.13966369628906, + 72.95000457763672, + 72.23511505126953, + 95.20930480957031, + 96.00569915771484, + 130.6641387939453, + 129.48353576660156, + 116.73051452636719, + 116.37247467041016, + 99.65167236328125, + 101.6120376586914, + 106.2371826171875, + 104.84210968017578 + ] + } + ], + "layout": { + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "LSTM: AllSkyPar" + }, + "xaxis": { + "title": { + "text": "Date" + } + }, + "yaxis": { + "title": { + "text": "allskypar" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "print(test_generator)\n", + "prediction = model.predict(test_generator)\n", + "\n", + "allskypar_train = allskypar_train.reshape((-1))\n", + "allskypar_test = allskypar_test.reshape((-1))\n", + "prediction = prediction.reshape((-1))\n", + "import plotly.graph_objects as go\n", + "\n", + "trace1 = go.Scatter(\n", + " x = date_train,\n", + " y = allskypar_train,\n", + " mode = 'lines',\n", + " name = 'Data'\n", + ")\n", + "trace2 = go.Scatter(\n", + " x = date_test,\n", + " y = prediction,\n", + " mode = 'lines',\n", + " name = 'Prediction'\n", + ")\n", + "trace3 = go.Scatter(\n", + " x = date_test,\n", + " y = allskypar_test,\n", + " mode='lines',\n", + " name = 'Ground Truth'\n", + ")\n", + "layout = go.Layout(\n", + " title = \"LSTM: AllSkyPar\",\n", + " xaxis = {'title' : \"Date\"},\n", + " yaxis = {'title' : \"allskypar\"}\n", + ")\n", + "fig = go.Figure(data=[trace3, trace2], layout=layout)\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:From C:\\Users\\wei_user\\.conda\\envs\\DLenv\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n", + "WARNING:tensorflow:From C:\\Users\\wei_user\\.conda\\envs\\DLenv\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n", + "INFO:tensorflow:Assets written to: C:\\Users\\wei_user\\AppData\\Local\\Temp\\tmptzjvlk_t\\assets\n" + ] + } + ], + "source": [ + "model.save(\"ALLSKY_SFC_PAR_TOT_10020148.h5\")\n", + "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", + "tflite_model = converter.convert()\n", + "\n", + "# Save the model.\n", + "with open('ALLSKY_SFC_PAR_TOT_10020148.tflite', 'wb') as f:\n", + " f.write(tflite_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "allskypar = allskypar.reshape((-1))\n", + "\n", + "def predict(num_prediction, model):\n", + " prediction_list = allskypar[-look_back:]\n", + " \n", + " for _ in range(num_prediction):\n", + " x = prediction_list[-look_back:]\n", + " x = x.reshape((1, look_back, 1))\n", + " out = model.predict(x)[0][0]\n", + " prediction_list = np.append(prediction_list, out)\n", + " prediction_list = prediction_list[look_back-1:]\n", + " \n", + " return prediction_list\n", + " \n", + "def predict_dates(num_prediction):\n", + " last_date = df['INDEX'].values[-1]\n", + " prediction_dates = pd.date_range(last_date, periods=num_prediction+1).tolist()\n", + " return prediction_dates\n", + "\n", + "num_prediction = 30\n", + "forecast = predict(num_prediction, model)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[120.46 107.55223846 103.08728027 97.21777344 91.40464783\n", + " 86.83331299 83.72988892 80.78806305 79.82680511 79.97673798\n", + " 80.69612122 80.42742157 77.98761749 73.56581879 68.69194031\n", + " 66.22226715 66.46763611 67.82965851 69.28223419 71.42388153\n", + " 68.34017181 64.36053467 62.29381943 60.38653183 60.02998352\n", + " 60.08250427 60.14657593 59.68810654 58.92700577 58.22602081\n", + " 57.59521866]\n" + ] + } + ], + "source": [ + "print(forecast)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Sunshine_LSTM.ipynb b/Sunshine_LSTM.ipynb index 19b494b..8c2d6ff 100644 --- a/Sunshine_LSTM.ipynb +++ b/Sunshine_LSTM.ipynb @@ -3,7 +3,6 @@ { "cell_type": "code", "execution_count": 1, - "id": "7e2abf3f", "metadata": {}, "outputs": [], "source": [ @@ -19,7 +18,6 @@ { "cell_type": "code", "execution_count": 2, - "id": "d5802021", "metadata": {}, "outputs": [ { @@ -60,7 +58,6 @@ { "cell_type": "code", "execution_count": 3, - "id": "7cfdf996", "metadata": {}, "outputs": [ { @@ -348,7 +345,6 @@ { "cell_type": "code", "execution_count": 4, - "id": "57d3f6e4", "metadata": {}, "outputs": [], "source": [ @@ -358,7 +354,6 @@ { "cell_type": "code", "execution_count": 5, - "id": "76855f5d", "metadata": {}, "outputs": [ { @@ -476,7 +471,6 @@ { "cell_type": "code", "execution_count": 6, - "id": "9ec9c9a0", "metadata": {}, "outputs": [], "source": [ @@ -488,7 +482,6 @@ { "cell_type": "code", "execution_count": 7, - "id": "4cbd266b", "metadata": {}, "outputs": [ { @@ -517,7 +510,6 @@ { "cell_type": "code", "execution_count": 8, - "id": "9e59770c", "metadata": {}, "outputs": [], "source": [ @@ -530,7 +522,6 @@ { "cell_type": "code", "execution_count": 16, - "id": "148c3570", "metadata": {}, "outputs": [ { @@ -617,7 +608,6 @@ { "cell_type": "code", "execution_count": 18, - "id": "b00bfc77", "metadata": {}, "outputs": [ { @@ -1867,7 +1857,6 @@ { "cell_type": "code", "execution_count": 19, - "id": "20fadf4b", "metadata": {}, "outputs": [ { @@ -1897,7 +1886,6 @@ { "cell_type": "code", "execution_count": 20, - "id": "1a7ec015", "metadata": {}, "outputs": [], "source": [ @@ -1927,7 +1915,6 @@ { "cell_type": "code", "execution_count": 21, - "id": "2bbec5da", "metadata": {}, "outputs": [ { @@ -1951,7 +1938,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -1965,7 +1952,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.7.6" } }, "nbformat": 4, diff --git a/Sunshine_LSTM.py b/Sunshine_LSTM.py new file mode 100644 index 0000000..d7032bc --- /dev/null +++ b/Sunshine_LSTM.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python +# coding: utf-8 + +# In[1]: + + +import pandas as pd +import numpy as np +import keras +import tensorflow as tf +from keras.preprocessing.sequence import TimeseriesGenerator +from keras.models import Sequential +from keras.layers import LSTM, Dense + + +# In[2]: + + +filename = "data.csv" +df = pd.read_csv(filename) +print(df.info()) + + +# In[3]: + + +df + + +# In[4]: + + +df = df.drop(['YEAR', 'MO', 'DY','ALLSKY_SFC_UVA','ALLSKY_SFC_UVB','ALLSKY_SFC_SW_DWN','CLRSKY_SFC_SW_DWN','WS2M','ALLSKY_KT','ALLSKY_NKT','ALLSKY_SFC_LW_DWN','CLRSKY_SFC_PAR_TOT'], axis = 1) + + +# In[5]: + + +df + + +# In[6]: + + +allskypar = df['ALLSKY_SFC_PAR_TOT'].values +allskypar = allskypar.reshape((-1,1)) +df.insert(0, 'INDEX', range(1, 1 + len(df))) + + +# In[7]: + + +split_percent = 0.99 +split = int(split_percent*len(allskypar)) + +allskypar_train = allskypar[:split] +allskypar_test = allskypar[split:] + +date_train = df['INDEX'][:split] +date_test = df['INDEX'][split:] + +print(len(allskypar_train)) +print(len(allskypar_test)) + + +# In[8]: + + +look_back = 30 + +train_generator = TimeseriesGenerator(allskypar_train, allskypar_train, length=look_back, batch_size=20) +test_generator = TimeseriesGenerator(allskypar_test, allskypar_test, length=look_back, batch_size=1) + + +# In[16]: + + +model = Sequential() +model.add( + LSTM(100, + activation='relu', + input_shape=(look_back,1)) +) +model.add(Dense(2)) +model.compile(optimizer='adam', loss='mse') + +num_epochs = 25 +model.fit(train_generator, epochs=num_epochs, verbose=1) + + +# In[18]: + + +print(test_generator) +prediction = model.predict(test_generator) + +allskypar_train = allskypar_train.reshape((-1)) +allskypar_test = allskypar_test.reshape((-1)) +prediction = prediction.reshape((-1)) +import plotly.graph_objects as go + +trace1 = go.Scatter( + x = date_train, + y = allskypar_train, + mode = 'lines', + name = 'Data' +) +trace2 = go.Scatter( + x = date_test, + y = prediction, + mode = 'lines', + name = 'Prediction' +) +trace3 = go.Scatter( + x = date_test, + y = allskypar_test, + mode='lines', + name = 'Ground Truth' +) +layout = go.Layout( + title = "LSTM: AllSkyPar", + xaxis = {'title' : "Date"}, + yaxis = {'title' : "allskypar"} +) +fig = go.Figure(data=[trace3, trace2], layout=layout) +fig.show() + + +# In[19]: + + +model.save("ALLSKY_SFC_PAR_TOT_10020148.h5") +converter = tf.lite.TFLiteConverter.from_keras_model(model) +tflite_model = converter.convert() + +# Save the model. +with open('ALLSKY_SFC_PAR_TOT_10020148.tflite', 'wb') as f: + f.write(tflite_model) + + +# In[20]: + + +allskypar = allskypar.reshape((-1)) + +def predict(num_prediction, model): + prediction_list = allskypar[-look_back:] + + for _ in range(num_prediction): + x = prediction_list[-look_back:] + x = x.reshape((1, look_back, 1)) + out = model.predict(x)[0][0] + prediction_list = np.append(prediction_list, out) + prediction_list = prediction_list[look_back-1:] + + return prediction_list + +def predict_dates(num_prediction): + last_date = df['INDEX'].values[-1] + prediction_dates = pd.date_range(last_date, periods=num_prediction+1).tolist() + return prediction_dates + +num_prediction = 30 +forecast = predict(num_prediction, model) + + +# In[21]: + + +print(forecast) +