如果模型出现了欠拟合(underfitting)问题,就需要提升模型的拟合能力(容量),例如增加网络的层数或每个隐藏层的单元数。
举个例子,代码如下:
class CircleModelV1(nn.Module):
    """Binary classifier built from three stacked ``nn.Linear`` layers (2 -> 10 -> 10 -> 1).

    No activation function sits between the layers, so the whole network is
    still an affine (linear) map of its input. It returns one raw logit per
    sample; pair it with ``nn.BCEWithLogitsLoss`` for training and
    ``torch.sigmoid`` for probabilities.
    """

    def __init__(self):
        super().__init__()
        # Creation order matters: each Linear draws its initial weights from the
        # global RNG, so keeping layer_1 -> layer_2 -> layer_3 keeps seeded runs stable.
        self.layer_1 = nn.Linear(in_features=2, out_features=10)
        self.layer_2 = nn.Linear(in_features=10, out_features=10)
        self.layer_3 = nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        """Map inputs of shape (..., 2) to logits of shape (..., 1)."""
        hidden = self.layer_1(x)
        hidden = self.layer_2(hidden)
        logits = self.layer_3(hidden)
        return logits
# Seed BEFORE building the model. Weight initialization is the random step in
# this example: the loop below is full-batch and the model has no random
# layers. Seeding only after the model exists therefore does not make runs
# reproducible.
torch.manual_seed(42)
model_1 = CircleModelV1().to("cpu")
print(model_1)

# The model outputs raw logits, so use the loss that applies the sigmoid internally.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model_1.parameters(), lr=0.1)

epochs = 1000

# Keep the data on the same device as the model.
X_train, y_train = X_train.to("cpu"), y_train.to("cpu")
X_test, y_test = X_test.to("cpu"), y_test.to("cpu")

for epoch in range(epochs):
    ### Training
    # model_1.eval() is called in the testing step of every epoch, so switch
    # back to training mode here. Without this, every epoch after the first
    # trains in eval mode, which matters once Dropout/BatchNorm layers are added.
    model_1.train()

    # 1. Forward pass (squeeze drops the trailing size-1 dim so logits match y_train's shape)
    y_logits = model_1(X_train).squeeze()
    y_pred = torch.round(torch.sigmoid(y_logits))  # logits -> probabilities -> prediction labels

    # 2. Calculate loss/accuracy
    loss = loss_fn(y_logits, y_train)
    acc = accuracy_fn(y_true=y_train, y_pred=y_pred)

    # 3. Optimizer zero grad
    optimizer.zero_grad()

    # 4. Loss backwards
    loss.backward()

    # 5. Optimizer step
    optimizer.step()

    ### Testing
    model_1.eval()
    with torch.inference_mode():
        # 1. Forward pass
        test_logits = model_1(X_test).squeeze()
        test_pred = torch.round(torch.sigmoid(test_logits))

        # 2. Calculate loss/accuracy
        test_loss = loss_fn(test_logits, y_test)
        test_acc = accuracy_fn(y_true=y_test, y_pred=test_pred)

    # Report both training and held-out metrics; the test metrics were
    # previously computed every epoch but never shown.
    if epoch % 100 == 0:
        print(
            f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {acc:.2f}% | "
            f"Test loss: {test_loss:.5f}, Test acc: {test_acc:.2f}%"
        )
# 结果如下
CircleModelV1(
(layer_1): Linear(in_features=2, out_features=10, bias=True)
(layer_2): Linear(in_features=10, out_features=10, bias=True)
(layer_3): Linear(in_features=10, out_features=1, bias=True)
)
Epoch: 0 | Loss: 0.69528, Accuracy: 51.38%
Epoch: 100 | Loss: 0.69325, Accuracy: 47.88%
Epoch: 200 | Loss: 0.69309, Accuracy: 49.88%
Epoch: 300 | Loss: 0.69303, Accuracy: 50.50%
Epoch: 400 | Loss: 0.69300, Accuracy: 51.38%
Epoch: 500 | Loss: 0.69299, Accuracy: 51.12%
Epoch: 600 | Loss: 0.69298, Accuracy: 51.50%
Epoch: 700 | Loss: 0.69298, Accuracy: 51.38%
Epoch: 800 | Loss: 0.69298, Accuracy: 51.50%
Epoch: 900 | Loss: 0.69298, Accuracy: 51.38%
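可以看到,即使使用了三层网络、每个隐藏层 10 个单元,训练损失仍停在 0.693 左右(约等于 ln 2,相当于对每个样本都预测 0.5 的概率),准确率也一直在 50% 上下,欠拟合并没有改善。原因在于各层之间没有非线性激活函数(如 ReLU):多个线性层叠加后仍然只是一个线性变换,加再多层也无法学习非线性的决策边界。

都看到这了,点个赞呗~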