mirror of
https://github.com/newnius/YAO-optimizer.git
synced 2025-12-15 17:06:44 +00:00
add files
This commit is contained in:
@@ -76,19 +76,19 @@ def train(config, train_X, train_Y, valid_X, valid_Y):
|
||||
if bad_epoch >= config.patience:
|
||||
print(" The training stops early in epoch {}".format(epoch))
|
||||
break
|
||||
return model
|
||||
|
||||
|
||||
def predict(config, test_X):
|
||||
def predict(config, test_X, model):
|
||||
config.dropout_rate = 0.1
|
||||
tf.reset_default_graph()
|
||||
|
||||
with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
|
||||
model = Model(config)
|
||||
#with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
|
||||
# model = Model(config)
|
||||
|
||||
test_len = len(test_X)
|
||||
with tf.Session() as sess:
|
||||
module_file = tf.train.latest_checkpoint(config.model_save_path)
|
||||
model.saver.restore(sess, module_file)
|
||||
#module_file = tf.train.latest_checkpoint(config.model_save_path)
|
||||
#model.saver.restore(sess, module_file)
|
||||
|
||||
result = np.zeros((test_len * config.time_step, config.output_size))
|
||||
for step in range(test_len):
|
||||
|
||||
Reference in New Issue
Block a user