diff --git a/.idea/workspace.xml b/.idea/workspace.xml index f14b820..754e9f9 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -134,7 +134,7 @@ - + diff --git a/model/model_tensorflow.py b/model/model_tensorflow.py index 846a516..1e54b7d 100644 --- a/model/model_tensorflow.py +++ b/model/model_tensorflow.py @@ -46,7 +46,7 @@ def train(config, train_X, train_Y, valid_X, valid_Y): bad_epoch = 0 for epoch in range(config.epoch): print("Epoch {}/{}".format(epoch, config.epoch)) - # 训练 + train_loss_array = [] for step in range(train_len // config.batch_size): feed_dict = {model.X: train_X[step * config.batch_size: (step + 1) * config.batch_size], @@ -54,7 +54,7 @@ def train(config, train_X, train_Y, valid_X, valid_Y): train_loss, _ = sess.run([model.loss, model.optim], feed_dict=feed_dict) train_loss_array.append(train_loss) - # 验证与早停 + valid_loss_array = [] for step in range(valid_len // config.batch_size): feed_dict = {model.X: valid_X[step * config.batch_size: (step + 1) * config.batch_size],