diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 1f02c39..655bc0b 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -2,7 +2,6 @@ - @@ -48,7 +47,7 @@ - + @@ -57,7 +56,7 @@ - + @@ -79,11 +78,11 @@ - - + + - + @@ -208,7 +207,7 @@ - + @@ -249,12 +248,12 @@ - @@ -369,11 +368,11 @@ - - + + - + diff --git a/serve.py b/serve.py index 0767392..3a50801 100644 --- a/serve.py +++ b/serve.py @@ -180,7 +180,11 @@ def train_models(): train_X, valid_X, train_Y, valid_Y = data_gainer.get_train_and_valid_data() print(train_X, valid_X, train_Y, valid_Y) - print(train_X.shape) + print(train_X.shape[0]) + if train_X.shape[0] < 500: + config.batch_size = 32 + if train_X.shape[0] < 200: + config.batch_size = 16 train(config, train_X, train_Y, valid_X, valid_Y)