mirror of
https://github.com/newnius/YAO-optimizer.git
synced 2025-06-07 07:01:56 +00:00
add files
This commit is contained in:
parent
c11466c958
commit
8fbce50cf8
@ -151,7 +151,7 @@
|
|||||||
<component name="PropertiesComponent">
|
<component name="PropertiesComponent">
|
||||||
<property name="WebServerToolWindowFactoryState" value="false" />
|
<property name="WebServerToolWindowFactoryState" value="false" />
|
||||||
<property name="aspect.path.notification.shown" value="true" />
|
<property name="aspect.path.notification.shown" value="true" />
|
||||||
<property name="com.android.tools.idea.instantapp.provision.ProvisionBeforeRunTaskProvider.myTimeStamp" value="1588157839783" />
|
<property name="com.android.tools.idea.instantapp.provision.ProvisionBeforeRunTaskProvider.myTimeStamp" value="1588157865997" />
|
||||||
<property name="go.gopath.indexing.explicitly.defined" value="true" />
|
<property name="go.gopath.indexing.explicitly.defined" value="true" />
|
||||||
<property name="nodejs_interpreter_path.stuck_in_default_project" value="undefined stuck path" />
|
<property name="nodejs_interpreter_path.stuck_in_default_project" value="undefined stuck path" />
|
||||||
<property name="nodejs_npm_path_reset_for_default_project" value="true" />
|
<property name="nodejs_npm_path_reset_for_default_project" value="true" />
|
||||||
|
4
main.py
4
main.py
@ -164,11 +164,11 @@ def main(config):
|
|||||||
|
|
||||||
if config.do_train:
|
if config.do_train:
|
||||||
train_X, valid_X, train_Y, valid_Y = data_gainer.get_train_and_valid_data()
|
train_X, valid_X, train_Y, valid_Y = data_gainer.get_train_and_valid_data()
|
||||||
train(config, train_X, train_Y, valid_X, valid_Y)
|
model = train(config, train_X, train_Y, valid_X, valid_Y)
|
||||||
|
|
||||||
if config.do_predict:
|
if config.do_predict:
|
||||||
test_X, test_Y = data_gainer.get_test_data(return_label_data=True)
|
test_X, test_Y = data_gainer.get_test_data(return_label_data=True)
|
||||||
pred_result = predict(config, test_X)
|
pred_result = predict(config, test_X, model)
|
||||||
draw(config, data_gainer, pred_result)
|
draw(config, data_gainer, pred_result)
|
||||||
|
|
||||||
|
|
||||||
|
@ -76,19 +76,19 @@ def train(config, train_X, train_Y, valid_X, valid_Y):
|
|||||||
if bad_epoch >= config.patience:
|
if bad_epoch >= config.patience:
|
||||||
print(" The training stops early in epoch {}".format(epoch))
|
print(" The training stops early in epoch {}".format(epoch))
|
||||||
break
|
break
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def predict(config, test_X):
|
def predict(config, test_X, model):
|
||||||
config.dropout_rate = 0.1
|
config.dropout_rate = 0.1
|
||||||
tf.reset_default_graph()
|
|
||||||
|
|
||||||
with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
|
#with tf.variable_scope("stock_predict", reuse=tf.AUTO_REUSE):
|
||||||
model = Model(config)
|
# model = Model(config)
|
||||||
|
|
||||||
test_len = len(test_X)
|
test_len = len(test_X)
|
||||||
with tf.Session() as sess:
|
with tf.Session() as sess:
|
||||||
module_file = tf.train.latest_checkpoint(config.model_save_path)
|
#module_file = tf.train.latest_checkpoint(config.model_save_path)
|
||||||
model.saver.restore(sess, module_file)
|
#model.saver.restore(sess, module_file)
|
||||||
|
|
||||||
result = np.zeros((test_len * config.time_step, config.output_size))
|
result = np.zeros((test_len * config.time_step, config.output_size))
|
||||||
for step in range(test_len):
|
for step in range(test_len):
|
||||||
|
Loading…
Reference in New Issue
Block a user