diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 354c3bb..1ae2f58 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -3,6 +3,8 @@
+
+
@@ -35,23 +37,23 @@
-
+
-
+
-
+
-
-
+
+
@@ -59,11 +61,11 @@
-
+
-
-
+
+
@@ -153,7 +155,7 @@
-
+
@@ -184,26 +186,25 @@
1588152877746
-
+
-
+
-
-
+
-
+
@@ -256,29 +257,29 @@
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/main.py b/main.py
index 51505a6..9d6b71b 100644
--- a/main.py
+++ b/main.py
@@ -164,11 +164,11 @@ def main(config):
if config.do_train:
train_X, valid_X, train_Y, valid_Y = data_gainer.get_train_and_valid_data()
- model = train(config, train_X, train_Y, valid_X, valid_Y)
+ train(config, train_X, train_Y, valid_X, valid_Y)
if config.do_predict:
test_X, test_Y = data_gainer.get_test_data(return_label_data=True)
- pred_result = predict(config, test_X, model)
+ pred_result = predict(config, test_X)
draw(config, data_gainer, pred_result)
diff --git a/model/model_tensorflow.py b/model/model_tensorflow.py
index 634f724..a07ea22 100644
--- a/model/model_tensorflow.py
+++ b/model/model_tensorflow.py
@@ -76,19 +76,18 @@ def train(config, train_X, train_Y, valid_X, valid_Y):
if bad_epoch >= config.patience:
print(" The training stops early in epoch {}".format(epoch))
break
- return model
-def predict(config, test_X, model):
+def predict(config, test_X):
config.dropout_rate = 0.1
- #with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
- #model = Model(config)
+ with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
+ model = Model(config)
test_len = len(test_X)
with tf.Session() as sess:
- #module_file = tf.train.latest_checkpoint(config.model_save_path)
- #model.saver.restore(sess, module_file)
+ module_file = tf.train.latest_checkpoint(config.model_save_path + config.model_name)
+ model.saver.restore(sess, module_file)
result = np.zeros((test_len * config.time_step, config.output_size))
for step in range(test_len):