mirror of
https://github.com/newnius/YAO-optimizer.git
synced 2025-06-06 22:51:55 +00:00
输入时的train_x不包括要预测的值
This commit is contained in:
parent
a394657a54
commit
3e7aa1fc91
35
main.py
35
main.py
@ -9,14 +9,18 @@ frame = "tensorflow"
|
||||
|
||||
|
||||
class Config:
|
||||
feature_columns = list([2, 5])
|
||||
# feature_columns = list(range(0,8))
|
||||
# label_columns = [5,6,7]
|
||||
feature_columns = list([2,5])#comment yqy
|
||||
# feature_columns = list([2]) #add yqy
|
||||
label_columns = [5]
|
||||
feature_and_label_columns = feature_columns + label_columns
|
||||
label_in_feature_columns = (lambda x, y: [x.index(i) for i in y])(feature_columns, label_columns)
|
||||
|
||||
predict_day = 1
|
||||
|
||||
input_size = len(feature_columns)
|
||||
# input_size = len(feature_columns)#comment yqy
|
||||
input_size = len( list([2]))#add yqy
|
||||
output_size = len(label_columns)
|
||||
|
||||
hidden_size = 128
|
||||
@ -24,8 +28,8 @@ class Config:
|
||||
dropout_rate = 0.2
|
||||
time_step = 5
|
||||
|
||||
do_train = True
|
||||
# do_train = False
|
||||
# do_train = True
|
||||
do_train = False
|
||||
do_predict = True
|
||||
add_train = False
|
||||
shuffle_train_data = True
|
||||
@ -48,12 +52,12 @@ class Config:
|
||||
continue_flag = "continue_"
|
||||
|
||||
#comment yqy
|
||||
# train_data_path = "./data/stock_data.csv"
|
||||
train_data_path = "./data/stock_data.csv"
|
||||
model_save_path = "./checkpoint/"
|
||||
figure_save_path = "./figure/"
|
||||
#comment end
|
||||
# add yqy
|
||||
train_data_path = "./data/stock_data_30.csv"
|
||||
# train_data_path = "./data/stock_data_30.csv"
|
||||
# model_save_path = "./checkpoint/30/"
|
||||
# figure_save_path = "./figure/30/"
|
||||
# add end
|
||||
@ -88,7 +92,8 @@ class Data:
|
||||
return init_data.values, init_data.columns.tolist()
|
||||
|
||||
def get_train_and_valid_data(self):
|
||||
feature_data = self.norm_data[:self.train_num]
|
||||
# feature_data = self.norm_data[:self.train_num] # comment yqy
|
||||
feature_data = self.norm_data[:self.train_num][:,1][:,np.newaxis] # add yqy
|
||||
label_data = self.norm_data[self.config.predict_day: self.config.predict_day + self.train_num,
|
||||
self.config.label_in_feature_columns]
|
||||
if not self.config.do_continue_train:
|
||||
@ -169,7 +174,14 @@ def draw_yqy(config, origin_data, predict_norm_data,mean_yqy,std_yqy):# 这里or
|
||||
label_name = 'high'
|
||||
label_column_num = 1
|
||||
|
||||
loss = np.mean((label_norm_data[config.predict_day:][:,1] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)[1]
|
||||
loss = np.mean((label_norm_data[config.predict_day:][:,1][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
|
||||
|
||||
# loss = np.mean((label_norm_data[config.predict_day:][:,5][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
|
||||
# loss2 = np.mean((label_norm_data[config.predict_day:][:,6][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
|
||||
# loss3 = np.mean((label_norm_data[config.predict_day:][:,7][:,np.newaxis] - predict_norm_data[:-config.predict_day][0:]) ** 2, axis=0)
|
||||
|
||||
|
||||
|
||||
print("The mean squared error of stock {} is ".format(label_name), loss)
|
||||
|
||||
# label_X = range(origin_data.data_num - origin_data.train_num - origin_data.start_num_in_test)
|
||||
@ -200,13 +212,14 @@ def main(config):
|
||||
|
||||
if config.do_predict:
|
||||
# add yqy
|
||||
test_data_yqy = pd.read_csv("./data/test_data.csv",usecols=list([2, 5]))
|
||||
test_data_values_yqy=test_data_yqy.values[:]
|
||||
test_data_yqy = pd.read_csv("./data/test_data.csv",usecols=list([2,5]))
|
||||
test_data_values_yqy=test_data_yqy.values
|
||||
# test_data_yqy=[104.3,104.39]
|
||||
test_X =data_gainer.get_test_data_yqy(test_data_values_yqy)
|
||||
# add end
|
||||
# test_X, test_Y = data_gainer.get_test_data(return_label_data=True)# comment yqy
|
||||
pred_result = predict(config, test_X)
|
||||
# pred_result = predict(config, test_X)
|
||||
pred_result = predict(config,test_X[:,:,0][:,:,np.newaxis])
|
||||
# draw(config, data_gainer, pred_result)# comment yqy
|
||||
draw_yqy(config, test_data_values_yqy, pred_result,mean_yqy,std_yqy)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user