XAI(eXplainable AI) - 설명하는 AI

설명 가능한 AI(eXplainable AI) - XAI

  • XAI는 인공지능의 행위와 도출한 결과를 사람이 이해할 수 있는 형태로 이를 설명하는 방법론과 분야를 일컫는다. 흔히 인공지능 기술은 복잡한 일련의 과정(딥러닝)을 통해 결론을 도출하나, 그 과정을 설명할 수 없는 블랙 박스로 여겨진다. XAI는 이를 해소 시킬 수 있는 개념으로 인공지능의 신뢰성을 높이는 역할하고 있습니다.

(1) CAM

1 - flatten 작업 직전 단계에서 이때까지 만들어진 중간 결과들(feature map)을 수집
2 - 중간 결과들에 대한 평균값을 구함
3 - 평균값과 최종 예측값 사이에서 한번 더 학습 -> 어떤 중간값이 최종 결정에 영향을 크게 줬는지 확인


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
!pip install tf-explain


import zipfile
zipfile.ZipFile('img.zip').extractall()

from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications import VGG16


from tf_explain.core.grad_cam import GradCAM
from tf_explain.core.occlusion_sensitivity import OcclusionSensitivity

import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
  • 원본 사진 파일 확인
1
2
3
4
5
6
7
8
9
10
11
12
13
print(glob.glob('*_0.jpg'))
# ['yawl_0.jpg', 'squirrel_monkey_0.jpg', 'persian_cat_0.jpg',
# 'maltese_0.jpg', 'grand_piano_0.jpg']

images_originals = []

for img_name in glob.glob("*_0.jpg"):
images_originals.append(mpimg.imread(img_name))

plt.figure(figsize = (20,20))
for i, img in enumerate(images_originals):
plt.subplot(5,5,i+1)
plt.imshow(img)
스크린샷 2023-01-10 오전 10 50 11
  • 이제 VGG16에서 이미지 분류된 결과를 통해 원본 사진을 왜 카테고리(input_list)로 분류하였는지를 확인하겠습니다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
model = VGG16(weights="imagenet", include_top=True)


input_list = ["maltese", "persian_cat", "squirrel_monkey", "grand_piano", "yawl"]
imagenet_index = ["153", "283", "382", "579", "914"]

#gradient CAM 알고리즘으로 XAI 생성
explainer = GradCAM()


for li, i in zip(input_list, imagenet_index):
print(i)
img = (load_img(f'{li}_0.jpg', target_size=(224,224)))
img = img_to_array(img)
data = ([img], None)
# print(data)
# print('--'*50)
grid = explainer.explain(data, model, int(i))# 설명하는 ai 생성
explainer.save(grid, '.', f'./{li}_cam.jpg') #_cam.jpg파일이란 이름으로 저장
  • 저장된 사진을 확인해봅시다.
1
2
3
4
5
6
7
8
9
10
11
12
13

#gradient CAM 알고리즘이 적용된 이미지를 저장할 리스트 정의
images_cams = []

plt.figure(figsize=(20,20))

for img in glob.glob("*_cam.jpg"):
images_cams.append(mpimg.imread(img))

# 출력
for i, img in enumerate(images_cams):
plt.subplot(5,5,i+1)
plt.imshow(img)
스크린샷 2023-01-10 오전 11 01 44

(2) 이미지를 일부를 가려서, 가려진 일부가 이미지 분류하는데 있어서 어느 정도 영향을 줬는지 계산하는 방식

작성 중

Author

InhwanCho

Posted on

2023-01-10

Updated on

2023-01-10

Licensed under

Comments