ML
[Numpy]numpy.ndarray에서 각 row마다 특정 column의 원소를 가져오고 싶을 때
프리랜서를꿈꾸는자
2020. 10. 25. 12:54
728x90
상황:
a = np.arange(20).reshape(4,5)
b = np.array([[0,1],[1,2],[2,3],[3,4]])
a의 first row에서는 [0,1]에 해당하는 entries를
a의 second row에서는 [1,2]에 해당하는 entries를
a의 third row에서는 [2,3]에 해당하는 entries를
a의 fourth row에서는 [3,4]에 해당하는 entries를
가져오고 싶다.
res = a[np.arange(4)[:, None], b]
설명
(1)
numpy.ndarray에 [:, None]을 하면 길이가 1인 axis를 하나 더 생성한다.
예를 들면
q = np.arange(10) # shape은 (10,)
w = q[:, None] # shape은 (10,1)
e = q[:, None] # shape은 (10, 1, 1)
reshape(-1,1)과 비슷해보이지만
w = q.reshape(-1,1) # (10,1)
e = w.reshape(-1,1) # (10,1), 즉 w와 shape이 같음
(2)
a의 third row 빼고 가져오고 싶다면
ind = np.array([0,1,3])
res = a[ind[:,None], b[ind]]
(3)
a의 first, third빼고 모두 가져오고 싶다면
ind = np.array([x for x in range(len(a)) if x not in [0,2]])
res = a[ind[:,None], b[ind]]
728x90