From 8fbce50cf88ecf80c3c194510cdc4728ffd7bc74 Mon Sep 17 00:00:00 2001 From: Newnius Date: Wed, 29 Apr 2020 18:59:40 +0800 Subject: [PATCH] add files --- .idea/workspace.xml | 2 +- main.py | 4 ++-- model/model_tensorflow.py | 12 ++++++------ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 2c2730e..5808446 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -151,7 +151,7 @@ - + diff --git a/main.py b/main.py index 9d6b71b..51505a6 100644 --- a/main.py +++ b/main.py @@ -164,11 +164,11 @@ def main(config): if config.do_train: train_X, valid_X, train_Y, valid_Y = data_gainer.get_train_and_valid_data() - train(config, train_X, train_Y, valid_X, valid_Y) + model = train(config, train_X, train_Y, valid_X, valid_Y) if config.do_predict: test_X, test_Y = data_gainer.get_test_data(return_label_data=True) - pred_result = predict(config, test_X) + pred_result = predict(config, test_X, model) draw(config, data_gainer, pred_result) diff --git a/model/model_tensorflow.py b/model/model_tensorflow.py index 444d6b1..8d7be13 100644 --- a/model/model_tensorflow.py +++ b/model/model_tensorflow.py @@ -76,19 +76,19 @@ def train(config, train_X, train_Y, valid_X, valid_Y): if bad_epoch >= config.patience: print(" The training stops early in epoch {}".format(epoch)) break + return model -def predict(config, test_X): +def predict(config, test_X, model): config.dropout_rate = 0.1 - tf.reset_default_graph() - with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE): - model = Model(config) + #with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE): + # model = Model(config) test_len = len(test_X) with tf.Session() as sess: - module_file = tf.train.latest_checkpoint(config.model_save_path) - model.saver.restore(sess, module_file) + #module_file = tf.train.latest_checkpoint(config.model_save_path) + #model.saver.restore(sess, module_file) result = np.zeros((test_len * config.time_step, config.output_size)) for step in range(test_len):