From 3e7aa1fc91196d33cb322649304a07114331d8bb Mon Sep 17 00:00:00 2001 From: yexiaoqi Date: Sat, 2 May 2020 13:05:29 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BE=93=E5=85=A5=E6=97=B6=E7=9A=84train=5Fx?= =?UTF-8?q?=E4=B8=8D=E5=8C=85=E6=8B=AC=E8=A6=81=E9=A2=84=E6=B5=8B=E7=9A=84?= =?UTF-8?q?=E5=80=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index 38a9c36..aa4e0d8 100644 --- a/main.py +++ b/main.py @@ -9,14 +9,18 @@ frame = "tensorflow" class Config: - feature_columns = list([2, 5]) + # feature_columns = list(range(0,8)) + # label_columns = [5,6,7] + feature_columns = list([2,5])#comment yqy + # feature_columns = list([2]) #add yqy label_columns = [5] feature_and_label_columns = feature_columns + label_columns label_in_feature_columns = (lambda x, y: [x.index(i) for i in y])(feature_columns, label_columns) predict_day = 1 - input_size = len(feature_columns) + # input_size = len(feature_columns)#comment yqy + input_size = len( list([2]))#add yqy output_size = len(label_columns) hidden_size = 128 @@ -24,8 +28,8 @@ class Config: dropout_rate = 0.2 time_step = 5 - do_train = True - # do_train = False + # do_train = True + do_train = False do_predict = True add_train = False shuffle_train_data = True @@ -48,12 +52,12 @@ class Config: continue_flag = "continue_" #comment yqy - # train_data_path = "./data/stock_data.csv" + train_data_path = "./data/stock_data.csv" model_save_path = "./checkpoint/" figure_save_path = "./figure/" #comment end # add yqy - train_data_path = "./data/stock_data_30.csv" + # train_data_path = "./data/stock_data_30.csv" # model_save_path = "./checkpoint/30/" # figure_save_path = "./figure/30/" # add end @@ -88,7 +92,8 @@ class Data: return init_data.values, init_data.columns.tolist() def get_train_and_valid_data(self): - feature_data = self.norm_data[:self.train_num] + # feature_data = self.norm_data[:self.train_num] # comment yqy + feature_data = self.norm_data[:self.train_num][:,1][:,np.newaxis] # add yqy label_data = self.norm_data[self.config.predict_day: self.config.predict_day + self.train_num, self.config.label_in_feature_columns] if not self.config.do_continue_train: @@ -169,7 +174,14 @@ def draw_yqy(config, origin_data, predict_norm_data,mean_yqy,std_yqy):# 这里or label_name = 'high' label_column_num = 1 - loss = np.mean((label_norm_data[config.predict_day:][:,1] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)[1] + loss = np.mean((label_norm_data[config.predict_day:][:,1][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0) + + # loss = np.mean((label_norm_data[config.predict_day:][:,5][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0) + # loss2 = np.mean((label_norm_data[config.predict_day:][:,6][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0) + # loss3 = np.mean((label_norm_data[config.predict_day:][:,7][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0) + + + print("The mean squared error of stock {} is ".format(label_name), loss) # label_X = range(origin_data.data_num - origin_data.train_num - origin_data.start_num_in_test) @@ -200,13 +212,14 @@ def main(config): if config.do_predict: # add yqy - test_data_yqy = pd.read_csv("./data/test_data.csv",usecols=list([2, 5])) - test_data_values_yqy=test_data_yqy.values[:] + test_data_yqy = pd.read_csv("./data/test_data.csv",usecols=list([2,5])) + test_data_values_yqy=test_data_yqy.values # test_data_yqy=[104.3,104.39] test_X =data_gainer.get_test_data_yqy(test_data_values_yqy) # add end # test_X, test_Y = data_gainer.get_test_data(return_label_data=True)# comment yqy - pred_result = predict(config, test_X) + # pred_result = predict(config, test_X) + pred_result = predict(config,test_X[:,:,0][:,:,np.newaxis]) # draw(config, data_gainer, pred_result)# comment yqy draw_yqy(config, test_data_values_yqy, pred_result,mean_yqy,std_yqy)