diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 33c0a98..e8847a6 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -203,7 +203,7 @@ - + diff --git a/serve.py b/serve.py index 2709f15..a81652b 100644 --- a/serve.py +++ b/serve.py @@ -148,7 +148,7 @@ def draw_yqy(config2, origin_data, predict_norm_data, mean_yqy, std_yqy): label_column_num = 3 loss = \ - np.mean((label_norm_data[config.predict_day:][:-3] - predict_norm_data[:-config.predict_day]) ** 2, axis=0) + np.mean((label_norm_data[config.predict_day:, 5:8] - predict_norm_data[:-config.predict_day]) ** 2, axis=0) print("The mean squared error of stock {} is ".format(label_name), loss) # label_X = range(origin_data.data_num - origin_data.train_num - origin_data.start_num_in_test)