오늘은 경사하강법을 직접 구현해보겠다. 학교 과제로 진행하게 되었지만 정리할겸 글을 남긴다.
모델은 간단한 선형 모델을 만들어보는걸 목표로 할 것이다!
14.3,21.6
5.3,11.2
9.2,19.1
11,21.1
9.9,18.1
14.9,23.3
11.6,21.9
8,17.4
13.1,22.5
14.8,23.2
5.7,12.5
8.2,16.6
7.2,15.2
10,18.7
9.1,17.2
13,21.6
10.3,19.3
5.9,12.2
6.1,12.8
15,22.4
10.3,21.3
15,21.6
11.3,22.1
8,16.4
11.8,22.4
테스트 케이스는 위와 같다. 왼쪽이 입력, 오른쪽이 출력이다. 편의상 출생 개월과 키라고 가정해 보자.
위 테스트 케이스를 csv파일로 저장한다음, numpy배열로 저장하기 위해 pandas 모듈을 사용할 것이다! 잘 이해가 되지않는다면 아래 코드를 우선 따라해 보자.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
raw_data = pd.read_csv('linear_regression_data01.csv', names =['age', 'tall'])
tall_data = np.asarray(raw_data['tall'].values.tolist()) # 키 데이터(Yn)
age_data = np.asarray(raw_data['age'].values.tolist()) # 나이 데이터(Xn)
tall_data.sort() # 데이터 정렬
age_data.sort()
# 실습1
plt.scatter(age_data, tall_data) # 데이터 위치를 점으로 표시
plt.title('HomeWork #1')
plt.xlabel('age')
plt.ylabel('tall')
plt.show()
먼저 위 데이터를 표로 만들어서 분석해 보자.
그 다음, 경사하강법을 구현해보기 전에 해석해를 통한 선형 모델을 구현해볼 생각이다.
결국 우리가 원하는 모델은 직선이다. 즉, y = ax + b 꼴로 나오는 직선중 데이터의 점들과 오차가 가장 적은 점을 탖는 문제인 것이다.
이때 위 식을 이용하면 해석해를 구할 수 있다. 물론 해당 방법은 아주 간단한 선형 모델에만 적용이 가능하다.
# 실습2
# 강의자료 수식을 코드로 표현
N =len(age_data)
Xn = sum(age_data)/N # Xn이 두번반복되기에 한번만 계산하도록 저장
W0 = (sum ( tall_data * (age_data -Xn) )/N) / ((sum( age_data **2 ) /N) - Xn **2) #분자
W1 = sum( tall_data - (W0 *age_data) ) /N # 분모
y = W0 * age_data + W1 # 해석해로 구한 선형모델
print('---------------------------------')
print('실습 #2')
print('W0 :', W0)
print('W1 :', W1)
해당 수식을 코드로 구현한 위 식을 통해 구한 값은 위와 같다.
# 실습3
# 해석해로 구한 w0,w1을 이용해 식을 출력
plt.grid(True, linestyle ='--')
plt.plot(age_data, y, label='linear-model', color ='limegreen')
plt.xlabel('age')
plt.ylabel('tall')
plt.scatter(age_data, tall_data, label='Data')
plt.legend()
plt.title('HomeWork #3')
plt.show()
이를 그래프로 같이 출력해본다면
위와같은 선형 모델이 나온다. 우리는 이 간단한 선형모델로 나이에 따른 키가 어떻게 될지 대략적으로 예측할 수 있게 된 것이다!
그렇다면 경사하강법을 통하여 어떻게 구현할 수 있을까?
좀 전에 우리는 가장 이상적인 선형모델을 찾기 위해선 ax + b꼴로 된 식을 찾는것이 목표라고 했다. 즉, a와 b값을 오차가 가장 적은 방향으로 조금씩 옮기다보면 정확한 값을 찾을 수 있다는 말이다.
경사하강법을 통하여 a와 b의 값을 조정하면서 정확한 값을 찾을 수 있다.
그전에 오차에 대해 조금 더 알아보자.
위 식에서 y햇은 결과이다. 즉 w0x + w1 꼴로된 우리의 예상결과와 실제결과인 yn의 차의 제곱의 평균이 얼마나 정확하게 예측을 했는지 지표로 사용할수 있는 평균제곱오차 즉, MSE라 불리는것이 된다.
코드로 만든다면 아래와 같다.
# 실습4
# MSE구하기
MSE = sum( (y -tall_data)**2 ) / N
print('---------------------------------')
print('실습 #4')
print('MSE :', MSE)
경사하강법은 위와같이 표현할 수 있는데, 해당 식의 의미를 곱씹어보자.
w0의 값과 그 MSE의 기울기값 즉, 오차가 얼마나 줄어들었거나 늘었는지를 파악하여 그 반대방향으로 오차를 줄여나가며 곱한다면 항아리에서 가장 오목한 부분을 찾듯 값을 찾을 수 있을 것이다.
# 실습5
# 머신러닝 교재 6p 식 활용
start_W0, start_W1 =3, 5 # 시작 w0,w1
cu_W0, cu_W1 = start_W0, start_W1 # 현재 w0,w1 값을 저장할 공간
lr =0.015 # learning rate
W0s, W1s = [cu_W0], [cu_W1] # 후에 그래프를 그리기위해 w0,w1를 저장할 공간
MSEs = [] # MSE를 저장해둘 고간
for i in range(5000): # 5000번 반복
y_pred = cu_W0 * age_data + cu_W1
error = tall_data - y_pred
error_mean = sum(error **2) / N # MSE값
a_diff =-(1 / N) * sum(age_data * (error)) # 기울기
b_diff =-(1 / N) * sum((error)) # 기울기값
cu_W0 = cu_W0 - lr * a_diff
cu_W1 = cu_W1 - lr * b_diff
if i % 500 ==0: # 그래프를 그리기위해 500번 반복시마다 값을 저장
W0s.append(cu_W0)
W1s.append(cu_W1)
MSEs.append(error_mean)
# 실습6
# 저장했던 값들 출력
print('---------------------------------')
print('실습 #6')
print('학습률 : ', lr)
print('초기값 : w0=', start_W0, ', w1 =', start_W1)
print('반복횟수 : 5000')
print('최종 평균제곱오차 :', error_mean)
print('최적 매개변수 : w0 =', cu_W0, ', w1 =', cu_W1)
print('---------------------------------')
이때 초기값은 보통 랜덤으로 지정해주나 본인은 생각나는대로 지정해줬고, 반복횟수는 5000, learning rate 즉 위에 식에서 알팍값에 해당하는 값은 0.015를 주었다. 5천번의 학습이 끝난 후 나온 w0과 w1의 값은 해석해로 구한 방법과 거의 일치하는 것을 볼 수 있다.
참 신기한게 w0과 w1을 어떠한 값을 주어도 학습을 반복하면 위 값으로 돌아오게 된다!
# 실습7
plt.grid(True, linestyle='--')
plt.xlabel('step')
plt.plot(range(0, len(W0s) * 500, 500), W0s, label='W1') # 500번마다 추출했기에 x좌표도 수정을 해주어야 함
plt.plot(range(0, len(W1s) * 500, 500), W1s, label='W0')
plt.scatter(range(0, len(W0s) * 500, 500), W0s)
plt.scatter(range(0, len(W1s) * 500, 500), W1s)
plt.plot(range(500, len(MSEs) * 500, 500), MSEs[1:], label='MSE')
plt.legend()
plt.title('HomeWork #7')
plt.show()
위 코드를 사용해서 그래프를 출력해보면 어떠한 과정을 통해 자리를 잡아가는지 볼 수 있다.
학습이 반복될수록 w1,w0값에 근접하게 가까워지는것을 알 수 있다.
'Python > 인공지능' 카테고리의 다른 글
가우스 함수를 이용한 선형회귀 모델 직접구현 (0) | 2022.04.17 |
---|---|
다중차원 선형모델 직접 구현 (0) | 2022.04.17 |
10_Text Detection(문자감지) (0) | 2022.03.26 |
09_순환신경망 RNN (0) | 2022.03.19 |
08_개,고양이 구분 인공지능 구현 (0) | 2022.03.18 |