1
0
mirror of https://github.com/newnius/YAO-optimizer.git synced 2025-12-15 09:06:43 +00:00

add files

This commit is contained in:
2020-04-29 18:50:53 +08:00
parent dca3c383ea
commit 30ea099b2b
3 changed files with 34 additions and 34 deletions

View File

@@ -76,19 +76,18 @@ def train(config, train_X, train_Y, valid_X, valid_Y):
if bad_epoch >= config.patience:
print(" The training stops early in epoch {}".format(epoch))
break
return model
def predict(config, test_X, model):
def predict(config, test_X):
config.dropout_rate = 0.1
#with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
#model = Model(config)
with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
model = Model(config)
test_len = len(test_X)
with tf.Session() as sess:
#module_file = tf.train.latest_checkpoint(config.model_save_path)
#model.saver.restore(sess, module_file)
module_file = tf.train.latest_checkpoint(config.model_save_path + config.model_name)
model.saver.restore(sess, module_file)
result = np.zeros((test_len * config.time_step, config.output_size))
for step in range(test_len):