diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 5ca6a0d..354c3bb 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -3,7 +3,6 @@
-
@@ -22,6 +21,7 @@
+
@@ -29,28 +29,29 @@
+
-
+
-
+
-
+
-
-
+
+
@@ -58,11 +59,11 @@
-
+
-
-
+
+
@@ -70,6 +71,11 @@
+
+
+
+
+
@@ -122,6 +128,16 @@
+
+
+
+
+
+
+
+
+
+
@@ -137,7 +153,7 @@
-
+
@@ -168,25 +184,26 @@
1588152877746
-
+
-
+
+
-
+
-
+
@@ -239,22 +256,25 @@
-
+
-
-
+
+
-
+
-
+
+
+
+
-
-
+
+
-
+
diff --git a/main.py b/main.py
index 9d6b71b..51505a6 100644
--- a/main.py
+++ b/main.py
@@ -164,11 +164,11 @@ def main(config):
if config.do_train:
train_X, valid_X, train_Y, valid_Y = data_gainer.get_train_and_valid_data()
- train(config, train_X, train_Y, valid_X, valid_Y)
+ model = train(config, train_X, train_Y, valid_X, valid_Y)
if config.do_predict:
test_X, test_Y = data_gainer.get_test_data(return_label_data=True)
- pred_result = predict(config, test_X)
+ pred_result = predict(config, test_X, model)
draw(config, data_gainer, pred_result)
diff --git a/model/model_tensorflow.py b/model/model_tensorflow.py
index c2be078..634f724 100644
--- a/model/model_tensorflow.py
+++ b/model/model_tensorflow.py
@@ -76,18 +76,19 @@ def train(config, train_X, train_Y, valid_X, valid_Y):
if bad_epoch >= config.patience:
print(" The training stops early in epoch {}".format(epoch))
break
+ return model
-def predict(config, test_X):
+def predict(config, test_X, model):
config.dropout_rate = 0.1
- with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
- model = Model(config)
+ #with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
+ #model = Model(config)
test_len = len(test_X)
with tf.Session() as sess:
- module_file = tf.train.latest_checkpoint(config.model_save_path)
- model.saver.restore(sess, module_file)
+ #module_file = tf.train.latest_checkpoint(config.model_save_path)
+ #model.saver.restore(sess, module_file)
result = np.zeros((test_len * config.time_step, config.output_size))
for step in range(test_len):