diff --git a/serve.py b/serve.py index 3ca79da..ce0e7d4 100644 --- a/serve.py +++ b/serve.py @@ -51,6 +51,7 @@ class Config: continue_flag = "continue_" train_data_path = "./data.csv" + test_data_path = "./test.csv" model_save_path = "./checkpoint/" figure_save_path = "./figure/" do_figure_save = False @@ -112,7 +113,7 @@ class Data: def get_test_data(self, return_label_data=False): init_data = pd.read_csv( - self.config.train_data_path, + self.config.test_data_path, usecols=self.config.feature_and_label_columns ) data, data_column_name = init_data.values, init_data.columns.tolist()