mirror of
https://github.com/newnius/YAO-optimizer.git
synced 2025-06-07 23:21:55 +00:00
输入时的train_x不包括要预测的值
This commit is contained in:
parent
a394657a54
commit
3e7aa1fc91
33
main.py
33
main.py
@ -9,14 +9,18 @@ frame = "tensorflow"
|
|||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
feature_columns = list([2, 5])
|
# feature_columns = list(range(0,8))
|
||||||
|
# label_columns = [5,6,7]
|
||||||
|
feature_columns = list([2,5])#comment yqy
|
||||||
|
# feature_columns = list([2]) #add yqy
|
||||||
label_columns = [5]
|
label_columns = [5]
|
||||||
feature_and_label_columns = feature_columns + label_columns
|
feature_and_label_columns = feature_columns + label_columns
|
||||||
label_in_feature_columns = (lambda x, y: [x.index(i) for i in y])(feature_columns, label_columns)
|
label_in_feature_columns = (lambda x, y: [x.index(i) for i in y])(feature_columns, label_columns)
|
||||||
|
|
||||||
predict_day = 1
|
predict_day = 1
|
||||||
|
|
||||||
input_size = len(feature_columns)
|
# input_size = len(feature_columns)#comment yqy
|
||||||
|
input_size = len( list([2]))#add yqy
|
||||||
output_size = len(label_columns)
|
output_size = len(label_columns)
|
||||||
|
|
||||||
hidden_size = 128
|
hidden_size = 128
|
||||||
@ -24,8 +28,8 @@ class Config:
|
|||||||
dropout_rate = 0.2
|
dropout_rate = 0.2
|
||||||
time_step = 5
|
time_step = 5
|
||||||
|
|
||||||
do_train = True
|
# do_train = True
|
||||||
# do_train = False
|
do_train = False
|
||||||
do_predict = True
|
do_predict = True
|
||||||
add_train = False
|
add_train = False
|
||||||
shuffle_train_data = True
|
shuffle_train_data = True
|
||||||
@ -48,12 +52,12 @@ class Config:
|
|||||||
continue_flag = "continue_"
|
continue_flag = "continue_"
|
||||||
|
|
||||||
#comment yqy
|
#comment yqy
|
||||||
# train_data_path = "./data/stock_data.csv"
|
train_data_path = "./data/stock_data.csv"
|
||||||
model_save_path = "./checkpoint/"
|
model_save_path = "./checkpoint/"
|
||||||
figure_save_path = "./figure/"
|
figure_save_path = "./figure/"
|
||||||
#comment end
|
#comment end
|
||||||
# add yqy
|
# add yqy
|
||||||
train_data_path = "./data/stock_data_30.csv"
|
# train_data_path = "./data/stock_data_30.csv"
|
||||||
# model_save_path = "./checkpoint/30/"
|
# model_save_path = "./checkpoint/30/"
|
||||||
# figure_save_path = "./figure/30/"
|
# figure_save_path = "./figure/30/"
|
||||||
# add end
|
# add end
|
||||||
@ -88,7 +92,8 @@ class Data:
|
|||||||
return init_data.values, init_data.columns.tolist()
|
return init_data.values, init_data.columns.tolist()
|
||||||
|
|
||||||
def get_train_and_valid_data(self):
|
def get_train_and_valid_data(self):
|
||||||
feature_data = self.norm_data[:self.train_num]
|
# feature_data = self.norm_data[:self.train_num] # comment yqy
|
||||||
|
feature_data = self.norm_data[:self.train_num][:,1][:,np.newaxis] # add yqy
|
||||||
label_data = self.norm_data[self.config.predict_day: self.config.predict_day + self.train_num,
|
label_data = self.norm_data[self.config.predict_day: self.config.predict_day + self.train_num,
|
||||||
self.config.label_in_feature_columns]
|
self.config.label_in_feature_columns]
|
||||||
if not self.config.do_continue_train:
|
if not self.config.do_continue_train:
|
||||||
@ -169,7 +174,14 @@ def draw_yqy(config, origin_data, predict_norm_data,mean_yqy,std_yqy):# 这里or
|
|||||||
label_name = 'high'
|
label_name = 'high'
|
||||||
label_column_num = 1
|
label_column_num = 1
|
||||||
|
|
||||||
loss = np.mean((label_norm_data[config.predict_day:][:,1] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)[1]
|
loss = np.mean((label_norm_data[config.predict_day:][:,1][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
|
||||||
|
|
||||||
|
# loss = np.mean((label_norm_data[config.predict_day:][:,5][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
|
||||||
|
# loss2 = np.mean((label_norm_data[config.predict_day:][:,6][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
|
||||||
|
# loss3 = np.mean((label_norm_data[config.predict_day:][:,7][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print("The mean squared error of stock {} is ".format(label_name), loss)
|
print("The mean squared error of stock {} is ".format(label_name), loss)
|
||||||
|
|
||||||
# label_X = range(origin_data.data_num - origin_data.train_num - origin_data.start_num_in_test)
|
# label_X = range(origin_data.data_num - origin_data.train_num - origin_data.start_num_in_test)
|
||||||
@ -201,12 +213,13 @@ def main(config):
|
|||||||
if config.do_predict:
|
if config.do_predict:
|
||||||
# add yqy
|
# add yqy
|
||||||
test_data_yqy = pd.read_csv("./data/test_data.csv",usecols=list([2,5]))
|
test_data_yqy = pd.read_csv("./data/test_data.csv",usecols=list([2,5]))
|
||||||
test_data_values_yqy=test_data_yqy.values[:]
|
test_data_values_yqy=test_data_yqy.values
|
||||||
# test_data_yqy=[104.3,104.39]
|
# test_data_yqy=[104.3,104.39]
|
||||||
test_X =data_gainer.get_test_data_yqy(test_data_values_yqy)
|
test_X =data_gainer.get_test_data_yqy(test_data_values_yqy)
|
||||||
# add end
|
# add end
|
||||||
# test_X, test_Y = data_gainer.get_test_data(return_label_data=True)# comment yqy
|
# test_X, test_Y = data_gainer.get_test_data(return_label_data=True)# comment yqy
|
||||||
pred_result = predict(config, test_X)
|
# pred_result = predict(config, test_X)
|
||||||
|
pred_result = predict(config,test_X[:,:,0][:,:,np.newaxis])
|
||||||
# draw(config, data_gainer, pred_result)# comment yqy
|
# draw(config, data_gainer, pred_result)# comment yqy
|
||||||
draw_yqy(config, test_data_values_yqy, pred_result,mean_yqy,std_yqy)
|
draw_yqy(config, test_data_values_yqy, pred_result,mean_yqy,std_yqy)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user