1
0
mirror of https://github.com/newnius/YAO-optimizer.git synced 2025-06-06 22:51:55 +00:00

输入时的train_x不包括要预测的值

This commit is contained in:
yexiaoqi 2020-05-02 13:05:29 +08:00
parent a394657a54
commit 3e7aa1fc91

35
main.py
View File

@ -9,14 +9,18 @@ frame = "tensorflow"
class Config:
feature_columns = list([2, 5])
# feature_columns = list(range(0,8))
# label_columns = [5,6,7]
feature_columns = list([2,5])#comment yqy
# feature_columns = list([2]) #add yqy
label_columns = [5]
feature_and_label_columns = feature_columns + label_columns
label_in_feature_columns = (lambda x, y: [x.index(i) for i in y])(feature_columns, label_columns)
predict_day = 1
input_size = len(feature_columns)
# input_size = len(feature_columns)#comment yqy
input_size = len( list([2]))#add yqy
output_size = len(label_columns)
hidden_size = 128
@ -24,8 +28,8 @@ class Config:
dropout_rate = 0.2
time_step = 5
do_train = True
# do_train = False
# do_train = True
do_train = False
do_predict = True
add_train = False
shuffle_train_data = True
@ -48,12 +52,12 @@ class Config:
continue_flag = "continue_"
#comment yqy
# train_data_path = "./data/stock_data.csv"
train_data_path = "./data/stock_data.csv"
model_save_path = "./checkpoint/"
figure_save_path = "./figure/"
#comment end
# add yqy
train_data_path = "./data/stock_data_30.csv"
# train_data_path = "./data/stock_data_30.csv"
# model_save_path = "./checkpoint/30/"
# figure_save_path = "./figure/30/"
# add end
@ -88,7 +92,8 @@ class Data:
return init_data.values, init_data.columns.tolist()
def get_train_and_valid_data(self):
feature_data = self.norm_data[:self.train_num]
# feature_data = self.norm_data[:self.train_num] # comment yqy
feature_data = self.norm_data[:self.train_num][:,1][:,np.newaxis] # add yqy
label_data = self.norm_data[self.config.predict_day: self.config.predict_day + self.train_num,
self.config.label_in_feature_columns]
if not self.config.do_continue_train:
@ -169,7 +174,14 @@ def draw_yqy(config, origin_data, predict_norm_data,mean_yqy,std_yqy):# 这里or
label_name = 'high'
label_column_num = 1
loss = np.mean((label_norm_data[config.predict_day:][:,1] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)[1]
loss = np.mean((label_norm_data[config.predict_day:][:,1][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
# loss = np.mean((label_norm_data[config.predict_day:][:,5][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
# loss2 = np.mean((label_norm_data[config.predict_day:][:,6][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
# loss3 = np.mean((label_norm_data[config.predict_day:][:,7][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
print("The mean squared error of stock {} is ".format(label_name), loss)
# label_X = range(origin_data.data_num - origin_data.train_num - origin_data.start_num_in_test)
@ -200,13 +212,14 @@ def main(config):
if config.do_predict:
# add yqy
test_data_yqy = pd.read_csv("./data/test_data.csv",usecols=list([2, 5]))
test_data_values_yqy=test_data_yqy.values[:]
test_data_yqy = pd.read_csv("./data/test_data.csv",usecols=list([2,5]))
test_data_values_yqy=test_data_yqy.values
# test_data_yqy=[104.3,104.39]
test_X =data_gainer.get_test_data_yqy(test_data_values_yqy)
# add end
# test_X, test_Y = data_gainer.get_test_data(return_label_data=True)# comment yqy
pred_result = predict(config, test_X)
# pred_result = predict(config, test_X)
pred_result = predict(config,test_X[:,:,0][:,:,np.newaxis])
# draw(config, data_gainer, pred_result)# comment yqy
draw_yqy(config, test_data_values_yqy, pred_result,mean_yqy,std_yqy)