mirror of https://github.com/newnius/YAO-optimizer.git (synced 2025-12-16 01:16:44 +00:00)
add files
model/model_pytorch.py (new file, 98 lines)
@@ -0,0 +1,98 @@
import torch
from torch.nn import Module, LSTM, Linear
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

class Net(Module):
    """A stacked LSTM followed by a per-time-step linear projection."""

    def __init__(self, config):
        super(Net, self).__init__()
        self.lstm = LSTM(input_size=config.input_size, hidden_size=config.hidden_size,
                         num_layers=config.lstm_layers, batch_first=True,
                         dropout=config.dropout_rate)
        self.linear = Linear(in_features=config.hidden_size, out_features=config.output_size)

    def forward(self, x, hidden=None):
        # x: (batch, seq_len, input_size); hidden is an optional (h_0, c_0) pair
        lstm_out, hidden = self.lstm(x, hidden)
        linear_out = self.linear(lstm_out)  # applied to every time step
        return linear_out, hidden
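
# Shape check for Net (illustration only; the SimpleNamespace values below are
# assumptions, not values shipped with this repository):
#
#     import types
#     cfg = types.SimpleNamespace(input_size=8, hidden_size=64, lstm_layers=2,
#                                 dropout_rate=0.2, output_size=1)
#     y, hidden = Net(cfg)(torch.randn(4, 30, 8))
#     # y.shape == (4, 30, 1); hidden is the LSTM's (h_n, c_n) pair
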
def train(config, train_X, train_Y, valid_X, valid_Y):
    train_X, train_Y = torch.from_numpy(train_X).float(), torch.from_numpy(train_Y).float()
    train_loader = DataLoader(TensorDataset(train_X, train_Y), batch_size=config.batch_size)

    valid_X, valid_Y = torch.from_numpy(valid_X).float(), torch.from_numpy(valid_Y).float()
    valid_loader = DataLoader(TensorDataset(valid_X, valid_Y), batch_size=config.batch_size)

    model = Net(config)
    if config.add_train:  # resume training from a previously saved checkpoint
        model.load_state_dict(torch.load(config.model_save_path + config.model_name))
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = torch.nn.MSELoss()

    valid_loss_min = float("inf")
    bad_epoch = 0
    for epoch in range(config.epoch):
        print("Epoch {}/{}".format(epoch, config.epoch))
        model.train()
        train_loss_array = []
        hidden_train = None
        for i, _data in enumerate(train_loader):
            _train_X, _train_Y = _data
            optimizer.zero_grad()
            pred_Y, hidden_train = model(_train_X, hidden_train)

            if not config.do_continue_train:
                hidden_train = None
            else:
                # Carry the hidden state across batches but detach it, so that
                # backpropagation is truncated at batch boundaries (this assumes
                # a constant batch size, since the state is reused as-is).
                h_0, c_0 = hidden_train
                h_0.detach_(), c_0.detach_()
                hidden_train = (h_0, c_0)
            loss = criterion(pred_Y, _train_Y)
            loss.backward()
            optimizer.step()
            train_loss_array.append(loss.item())

        model.eval()
        valid_loss_array = []
        hidden_valid = None
        with torch.no_grad():  # no gradients needed during validation
            for _valid_X, _valid_Y in valid_loader:
                pred_Y, hidden_valid = model(_valid_X, hidden_valid)
                if not config.do_continue_train:
                    hidden_valid = None
                loss = criterion(pred_Y, _valid_Y)
                valid_loss_array.append(loss.item())

        valid_loss_cur = np.mean(valid_loss_array)
        print("The train loss is {:.4f}. ".format(np.mean(train_loss_array)),
              "The valid loss is {:.4f}.".format(valid_loss_cur))

        # Early stopping: keep the best checkpoint; stop once the validation
        # loss has not improved for config.patience consecutive epochs.
        if valid_loss_cur < valid_loss_min:
            valid_loss_min = valid_loss_cur
            bad_epoch = 0
            torch.save(model.state_dict(), config.model_save_path + config.model_name)
        else:
            bad_epoch += 1
            if bad_epoch >= config.patience:
                print(" The training stops early in epoch {}".format(epoch))
                break
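
# Note: `config` is not defined in this file. train() and predict() only read
# attributes from it: input_size, hidden_size, lstm_layers, dropout_rate,
# output_size (Net); batch_size, learning_rate, epoch, patience (training);
# add_train, do_continue_train (flags); model_save_path, model_name
# (checkpoint location). See the usage sketch after predict() below.
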
def predict(config, test_X):
    test_X = torch.from_numpy(test_X).float()
    test_set = TensorDataset(test_X)
    test_loader = DataLoader(test_set, batch_size=1)

    model = Net(config)
    model.load_state_dict(torch.load(config.model_save_path + config.model_name))

    result = torch.Tensor()

    model.eval()
    hidden_predict = None
    with torch.no_grad():  # inference only
        for _data in test_loader:
            data_X = _data[0]
            pred_X, hidden_predict = model(data_X, hidden_predict)
            # Drop the batch dimension and append this sample's predictions.
            cur_pred = torch.squeeze(pred_X, dim=0)
            result = torch.cat((result, cur_pred), dim=0)

    return result.detach().numpy()
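
A minimal end-to-end run might look like the sketch below. The config values,
array shapes, and checkpoint path are assumptions for illustration; none of
them are defined in this commit, and the `checkpoints/` directory must already
exist for torch.save to succeed.

    import types
    import numpy as np

    config = types.SimpleNamespace(
        input_size=8, hidden_size=64, lstm_layers=2, dropout_rate=0.2,
        output_size=1,
        batch_size=32, learning_rate=1e-3, epoch=20, patience=5,
        add_train=False, do_continue_train=False,
        model_save_path="checkpoints/", model_name="model.pth")

    # (batch, seq_len, features) arrays, matching batch_first=True above
    train_X = np.random.rand(256, 30, 8); train_Y = np.random.rand(256, 30, 1)
    valid_X = np.random.rand(64, 30, 8);  valid_Y = np.random.rand(64, 30, 1)

    train(config, train_X, train_Y, valid_X, valid_Y)
    pred = predict(config, np.random.rand(16, 30, 8))  # -> ndarray of shape (480, 1)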