상황:
a:numpy.ndarray, 2-dimensional,
a의 각 row마다 k개의 top values뽑기
ex)
a=np.random.randint(0,10,(4,5)) # 0에서 9까지 values중 random integers뽑아서 shape이 (4,5)인 것
k=3
res = a[np.arange(len(a))[:, None], np.argsort(a)[:, -k:]][:, ::-1]
설명:
(1)
np.argsort는 ascending order로 indices를 뽑아냄
(2)
the smallest top k를 구하려면 np.argsort(a)[:, :k], 그리고 뒤에 [:, ::-1] 없어야함
(3)
만약 the top k largest/smallest values 를 ordering없이 뽑으려면
np.argsort보다 np.partition이 낫다. 왜냐하면 np.argsort는 sorting(O(n*log(n))을 하지만 후자는 sorting하지 않음(O(n))
즉,
res = np.partition(a, -k)[:, -k:] 하면은 top k values(not ordered)를 얻을 수 있다.
(4)
the top k largest/smallest values의 indices를 ordering없이 뽑으려면
np.partition말고 np.argpartition을 활용하면 된다.
즉,
res = np.argparition(a, -k)[:, -k:]
(5)
a에 negation붙여서 하진 말 것, 그러면 new object를 만들어서 하기 때문에 memory 더 쓰게 됨
참고자료:
kanoki.org/2020/01/14/find-k-smallest-and-largest-values-and-its-indices-in-a-numpy-array/
'ML' 카테고리의 다른 글
[Numpy]ndarray가 (built-in)list보다 빠른 이유 (0) | 2020.11.02 |
---|---|
[Clustering]K-means Clustering (0) | 2020.10.29 |
[Numpy]numpy.ndarray에서 각 row마다 특정 column의 원소를 가져오고 싶을 때 (0) | 2020.10.25 |
[Numpy]numpy.ndarray 각 원소에 dictionary map할 때 (0) | 2020.10.25 |
[Pandas]중복인것만 살리기 (0) | 2020.10.25 |