From 30ea099b2b7edd3f914fa05c79ae197d646e1fae Mon Sep 17 00:00:00 2001 From: Newnius Date: Wed, 29 Apr 2020 18:50:53 +0800 Subject: [PATCH] add files --- .idea/workspace.xml | 53 ++++++++++++++++++++------------------- main.py | 4 +-- model/model_tensorflow.py | 11 ++++---- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 354c3bb..1ae2f58 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -3,6 +3,8 @@ + + - + - - + + @@ -59,11 +61,11 @@ - + - - + + @@ -153,7 +155,7 @@ - + @@ -184,26 +186,25 @@ - - - + - + @@ -256,29 +257,29 @@ - - - - - - - - - - - - + + + + + + + + + + + + diff --git a/main.py b/main.py index 51505a6..9d6b71b 100644 --- a/main.py +++ b/main.py @@ -164,11 +164,11 @@ def main(config): if config.do_train: train_X, valid_X, train_Y, valid_Y = data_gainer.get_train_and_valid_data() - model = train(config, train_X, train_Y, valid_X, valid_Y) + train(config, train_X, train_Y, valid_X, valid_Y) if config.do_predict: test_X, test_Y = data_gainer.get_test_data(return_label_data=True) - pred_result = predict(config, test_X, model) + pred_result = predict(config, test_X) draw(config, data_gainer, pred_result) diff --git a/model/model_tensorflow.py b/model/model_tensorflow.py index 634f724..a07ea22 100644 --- a/model/model_tensorflow.py +++ b/model/model_tensorflow.py @@ -76,19 +76,18 @@ def train(config, train_X, train_Y, valid_X, valid_Y): if bad_epoch >= config.patience: print(" The training stops early in epoch {}".format(epoch)) break - return model -def predict(config, test_X, model): +def predict(config, test_X): config.dropout_rate = 0.1 - #with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE): - #model = Model(config) + with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE): + model = Model(config) test_len = len(test_X) with tf.Session() as sess: - #module_file = tf.train.latest_checkpoint(config.model_save_path) - #model.saver.restore(sess, module_file) + module_file = tf.train.latest_checkpoint(config.model_save_path + config.model_name) + model.saver.restore(sess, module_file) result = np.zeros((test_len * config.time_step, config.output_size)) for step in range(test_len):