diff --git a/baler/modules/training.py b/baler/modules/training.py index ca189def..e5318f75 100644 --- a/baler/modules/training.py +++ b/baler/modules/training.py @@ -97,7 +97,7 @@ def fit( running_loss += loss.item() epoch_loss = running_loss / (idx + 1) - print(f"# Finished. Training Loss: {loss:.6f}") + print(f"# Finished. Training Loss: {epoch_loss:.6f}") return epoch_loss, mse_loss, l1_loss, model @@ -133,7 +133,7 @@ def validate(model, test_dl, model_children, reg_param): running_loss += loss.item() epoch_loss = running_loss / (idx + 1) - print(f"# Finished. Validation Loss: {loss:.6f}") + print(f"# Finished. Validation Loss: {epoch_loss:.6f}") return epoch_loss