Przejdź do treści
Strona główna » Blog » Python – regresja liniowa

Python – regresja liniowa

Regresja liniowa to bardzo prosta i użyteczna metoda matematyczna, stosowana w analizie wszelkich danych. Mówiąc najprościej celem regresji liniowej jest znalezienie równania linii prostej, która najlepiej „pasuje” do zbioru danych x,y. Opis teoretyczny można znaleźć na wielu różnych stronach, między innymi na Wikipedii.

Zastosowanie regresji liniowej

Regresja liniowa jest szeroko stosowana w różnych dziedzinach, takich jak ekonomia, nauki społeczne, nauki przyrodnicze, inżynieria itp. O tym jak ważna jest ta metoda świadczy między innymi to, że możemy znaleźć dziesiątki książek poświęconych wyłącznie temu zagadnieniu. Regresja liniowa pojawia się niemal w każdym podręczniku do ekonometrii1. Również w badaniach naukowych jest ona wykorzystywana bardzo często. Naukowcy starają się znaleźć zależności pomiędzy zmiennymi, dzięki czemu mogą zrozumieć badane zagadnienie. W książce Linear Regression Analysis: Theory and Computing2 opisano eksperyment, w którym badano wpływ palenia papierosów na śmiertelność. Możemy również znaleźć inne ciekawe badania naukowe dotyczące zastosowania regresji liniowej. Przykładem może być praca naukowa3, w której wykorzystano tę metodę do określania wieku autora na podstawie pisanych przez niego tekstów.

Biblioteki Pythona do regresji liniowej

W języku Python mamy dwie główne biblioteki, które mogą być wykorzystane do regresji liniowej: scikit-learn oraz statsmodels. Druga z wymienionych jest dużo bardziej rozbudowana i nie będziemy mieli potrzeby korzystać z jej funkcjonalności. Biblioteka scikit-learn zawiera wszystkie kluczowe funkcje na potrzeby większości analiz.
Wykonanie regresji liniowej w Pythonie z użyciem bilbioteki scikitlearn jest bardzo proste. Oprócz tej biblioteki posłużymy się również:

  • numpy – biblioteka do obliczeń, w szczególności na macierzach,
  • matplotlib – biblioteka do tworzenia wykresów.

Program do regresji liniowej

Najpierw zaimportujemy wspomniane biblioteki:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

Do regresji wykorzystamy dane wygenerowane przy użyciu biblioteki numpy. Wektor X będzie zawierał losowe wartości z przedziału od 0 do 1, natomiast y to wartości obliczone z równania linii prostej, do których dodany będzie losowy „błąd”.

np.random.seed(0)
X = np.random.rand(100, 1)  # Zmienna niezależna
y = 3 * X + 2 + np.random.randn(100, 1)/10  # Zmienna zależna + "błąd"

W następnym kroku utworzymy model regresji liniowej. Metoda LinearRegression() została wcześniej zaimportowana z biblioteki sklearn. Aby dopasować linię prostą do danych X, y, wywołujemy metodę fit().

# Utworzenie modelu regresji liniowej i dopasowanie
model = LinearRegression()
model.fit(X, y)

Do oceny dopasania prostej posłużymy się współczynnikiem R2. Z teorii wiemy, że jego wartość powinna być możliwie bliska jedności, co świadczy o dobrym dopasowaniu prostej do danych. Jeżeli wartość jest znacznie mniejsza od jedności („znacznie” zależy od konkretnego problemu), to pomiędzy danymi X i y nie zachodzi zależność liniowa. Aby wyznaczyć R2 najpierw obliczamy wartości y przewidywane przez stworzony model regresji liniowej (metoda predict) dla istniejących wartości X. Następnie otrzymane wartości wykorzystujemy do wywołania metody r2score.

# Wartości przewidywane przez model regresji liniowej
y_pred = model.predict(X)

# Obliczenie R-kwadrat.
r2 = r2_score(y, y_pred)

print("R kwadrat:", r2)

Parametry prostej a*x + b otrzymujemy z atrybutów obiektu model. coef_ to parametr a (nachylenie), a intercept_ to b (przecięcie z osią OY).

# Współczynnik a w równaniu y = a*x + b
a = model.coef_
print("Współczynnik a:", a)

# Współczynnik b w równaniu y = a*x + b
b = model.intercept_
print("Współczynnik b:", b)

W ostatnim kroku zwizualizujemy dane i wyniki. Do tego celu posłuży nam biblioteka Matplotlib. Utworzymy wykres punktowy przy użyciu metody scatter. Metody xlabel i ylabel służą do opisania osi, legend() do utworzenia legendy, na której wyświetlone zostaną wcześniej nadane etykiety (label). Wykres wyświetlamy przy użyciu show().

plt.scatter(X, y, color='blue', label='Dane')
plt.plot(X, y_pred, color='red', linewidth=2, label='Regresja liniowa')
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.show()
Python regresa liniowa - rysunek w Matplotlib
Regresja liniowa (R2 = 0,99)

#data science #machine learning #statystyka #analiza danych #python regresja liniowa

Regresję możesz wykonać przy użyciu naszego narzędzia online

Więcej na temat obliczeń z użyciem Pythona dowiesz się z naszych kursów programowania.

Podsumowanie

Z tego artykułu dowiedziałeś się czym jest regresja liniowa i jak napisać program, który tworzy model regresji liniowej w Pythonie. W szczególności nauczyłeś się jak:

  • wykorzystać bibliotekę scikit-learn do stworzenia modelu prostej regresji liniowej,
  • wyznaczyć współczynniki linii prostej oraz współczynnik dopasowania R2,
  • narysować wykres z danymi oraz dopasowaną prostą.

Odnośniki

  1. On Using Linear Regressions in Welfare Economics: Journal of Business & Economic Statistics: Vol 14, No 4 (tandfonline.com)
  2. Linear Regression Analysis: Theory and Computing | Guide books | ACM Digital Library
  3. Nguyen, Dong, Noah A. Smith, and Carolyn Rose. „Author age prediction from text using linear regression.” Proceedings of the 5th ACL-HLT workshop on language technology for cultural heritage, social sciences, and humanities. 2011.