diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 5ca6a0d..354c3bb 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -3,7 +3,6 @@ - - + - - + + @@ -58,11 +59,11 @@ - + - - + + @@ -70,6 +71,11 @@ + + + + + @@ -122,6 +128,16 @@ + + + + + + + + + + @@ -137,7 +153,7 @@ - + @@ -168,25 +184,26 @@ - + - + - + @@ -239,22 +256,25 @@ - + - - + + - + - + + + + - - + + - + diff --git a/main.py b/main.py index 9d6b71b..51505a6 100644 --- a/main.py +++ b/main.py @@ -164,11 +164,11 @@ def main(config): if config.do_train: train_X, valid_X, train_Y, valid_Y = data_gainer.get_train_and_valid_data() - train(config, train_X, train_Y, valid_X, valid_Y) + model = train(config, train_X, train_Y, valid_X, valid_Y) if config.do_predict: test_X, test_Y = data_gainer.get_test_data(return_label_data=True) - pred_result = predict(config, test_X) + pred_result = predict(config, test_X, model) draw(config, data_gainer, pred_result) diff --git a/model/model_tensorflow.py b/model/model_tensorflow.py index c2be078..634f724 100644 --- a/model/model_tensorflow.py +++ b/model/model_tensorflow.py @@ -76,18 +76,19 @@ def train(config, train_X, train_Y, valid_X, valid_Y): if bad_epoch >= config.patience: print(" The training stops early in epoch {}".format(epoch)) break + return model -def predict(config, test_X): +def predict(config, test_X, model): config.dropout_rate = 0.1 - with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE): - model = Model(config) + #with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE): + #model = Model(config) test_len = len(test_X) with tf.Session() as sess: - module_file = tf.train.latest_checkpoint(config.model_save_path) - model.saver.restore(sess, module_file) + #module_file = tf.train.latest_checkpoint(config.model_save_path) + #model.saver.restore(sess, module_file) result = np.zeros((test_len * config.time_step, config.output_size)) for step in range(test_len):