Логистическая регрессия - это алгоритм машинного обучения, который моделирует вероятность события на основе входных переменных. Он широко используется в задачах бинарной классификации.
В этой статье мы создадим модель логистической регрессии с помощью PyTorch, библиотеки глубокого обучения с открытым исходным кодом.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
В качестве примера набора данных мы используем набор данных Iris.
data = pd.read_csv('iris.csv')
data = data.drop('species_id', axis=1)
X_train, X_test, y_train, y_test = train_test_split(data.drop('species', axis=1), data['species'], test_size=0.2)
y_train = y_train.astype('category').cat.codes
y_test = y_test.astype('category').cat.codes
train_data = DataLoader(torch.utils.data.TensorDataset(X_train.values, y_train.values), batch_size=64)
test_data = DataLoader(torch.utils.data.TensorDataset(X_test.values, y_test.values), batch_size=64)
class LogisticRegression(nn.Module):
def __init__(self, input_dim, output_dim):
super(LogisticRegression, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
return F.log_softmax(self.linear(x), dim=1)
model = LogisticRegression(X_train.shape[1], 3)
optimizer = optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100):
for x, y in train_data:
optimizer.zero_grad()
logits = model(x)
loss = F.nll_loss(logits, y)
loss.backward()
optimizer.step()
if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch + 1}/100], Loss: {loss.item()}')
correct_count, total_count = 0, 0
with torch.no_grad():
for x, y in test_data:
logits = model(x)
pred = logits.argmax(dim=1)
correct_count += (pred == y).sum().item()
total_count += y.size(0)
print(f'Accuracy: {correct_count / total_count:.2%}')
Мы успешно создали и обучили модель логистической регрессии с помощью PyTorch. Достигнутая нами точность составляет около 98%, что указывает на хорошее качество модели.
Этот пример можно расширить, включив в него больше функций и поэкспериментировав с различными гиперпараметрами.