diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 5808446..40f7d01 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -151,7 +151,7 @@
-
+
diff --git a/main.py b/main.py
index 173145d..9d6b71b 100644
--- a/main.py
+++ b/main.py
@@ -164,11 +164,11 @@ def main(config):
     if config.do_train:
         train_X, valid_X, train_Y, valid_Y = data_gainer.get_train_and_valid_data()
-        model = train(config, train_X, train_Y, valid_X, valid_Y)
-
+        train(config, train_X, train_Y, valid_X, valid_Y)
+
     if config.do_predict:
         test_X, test_Y = data_gainer.get_test_data(return_label_data=True)
-        pred_result = predict(config, test_X, model)
+        pred_result = predict(config, test_X)
         draw(config, data_gainer, pred_result)
diff --git a/model/model_tensorflow.py b/model/model_tensorflow.py
index 8d7be13..444d6b1 100644
--- a/model/model_tensorflow.py
+++ b/model/model_tensorflow.py
@@ -76,19 +76,19 @@ def train(config, train_X, train_Y, valid_X, valid_Y):
         if bad_epoch >= config.patience:
             print(" The training stops early in epoch {}".format(epoch))
             break
-    return model
 
-def predict(config, test_X, model):
+def predict(config, test_X):
     config.dropout_rate = 0.1
+    tf.reset_default_graph()
-    #with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
-    #    model = Model(config)
+    with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
+        model = Model(config)
     test_len = len(test_X)
     with tf.Session() as sess:
-        #module_file = tf.train.latest_checkpoint(config.model_save_path)
-        #model.saver.restore(sess, module_file)
+        module_file = tf.train.latest_checkpoint(config.model_save_path)
+        model.saver.restore(sess, module_file)
         result = np.zeros((test_len * config.time_step, config.output_size))
         for step in range(test_len):
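
For reference, a minimal sketch of what predict() looks like after this patch (TensorFlow 1.x graph mode). The Model class and the config fields (model_save_path, time_step, output_size, dropout_rate) are assumed to be the ones already defined in model/model_tensorflow.py; the per-step inference loop and the return value are sketched because the diff cuts off before them.

import numpy as np
import tensorflow as tf

# Illustrative sketch only: Model is the class defined earlier in
# model/model_tensorflow.py, not something introduced by this patch.

def predict(config, test_X):
    config.dropout_rate = 0.1          # effectively disable dropout at inference time
    tf.reset_default_graph()           # drop whatever graph train() left behind
    with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
        model = Model(config)          # rebuild the network structure in the fresh graph
    test_len = len(test_X)
    with tf.Session() as sess:
        # Restore the weights that train() saved to disk instead of reusing the
        # in-memory model object (train() no longer returns it).
        module_file = tf.train.latest_checkpoint(config.model_save_path)
        model.saver.restore(sess, module_file)
        result = np.zeros((test_len * config.time_step, config.output_size))
        for step in range(test_len):
            # ... run the restored graph on test_X[step] and fill `result`
            # (loop body unchanged by this patch, omitted here)
            pass
    return result

With this in place, main() only needs the call order shown in the main.py hunk: train() persists its checkpoint, and predict() rebuilds the graph and restores that checkpoint, so no model object has to be passed between the two functions.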