numpy的维度增删函数——np.expand_dims()和np.squeeze()
·
numpy是python中重要的科学计算库,常用于数组或矩阵的计算,此时便会涉及到数组维度匹配问题,虽然numpy有broadcast机制,但为避免一些难以察觉的bug,有必要对数组的维度进行增删操作,以使数组维度相匹配。
numpy为用户提供了维度扩展函数np.expand_dims()和维度删减函数np.squeeze()。
np.expand_dims()
np.expand_dims(a, axis)
参数如下:
- a : array_like
- axis : int
Position in the expanded axes where the new axis is placed.
该函数的作用是在指定轴axis上增加数组a的一个维度,即,在第“axis”维,加一个维度出来,原先在“axis”左边的维度保持位置不变,在“axis”右边的维度整体右移。
注意:该函数不改变输入数组a,而是产生一个新数组,新数组中的元素与原数组完全相同。
假设三维数组a的shape是(m, n, c),则
- np.expand_dims(a, axis=0)表示在a的第一个维度上增加一个新的维度,而其他维度整体往右移,最终得到shape为(1, m, n, c)的新数组,新数组中的元素与原数组完全相同。
import numpy as np
a = np.reshape(list(range(24)), (2, 3, 4))
a_new = np.expand_dims(a, axis=0)
print('a =', a)
print('a_new =', a_new)
print('a.shape = ', a.shape)
print('a_new.shape = ', a_new.shape)
输出结果:
a = [[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
a_new = [[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]]
a.shape = (2, 3, 4)
a_new.shape = (1, 2, 3, 4)
- np.expand_dims(a, axis=1)将得到shape为(m, 1, n, c)的新数组,新数组中的元素与原数组a完全相同。
- np.expand_dims(a, axis=2)将得到shape为(m, n, 1, c)的新数组,新数组中的元素与原数组a完全相同。
- np.expand_dims(a, axis=3)将得到shape为(m, n, c, 1)的新数组,新数组中的元素与原数组a完全相同。
import numpy as np
a = np.reshape(list(range(24)), (2, 3, 4))
print('a =', a)
print('np.expand_dims(a, axis=1) =', np.expand_dims(a, axis=1))
print('np.expand_dims(a, axis=2) =', np.expand_dims(a, axis=2))
print('np.expand_dims(a, axis=3) =', np.expand_dims(a, axis=3))
print('a.shape = ', a.shape)
print('np.expand_dims(a, axis=1).shape =', np.expand_dims(a, axis=1).shape)
print('np.expand_dims(a, axis=2).shape =', np.expand_dims(a, axis=2).shape)
print('np.expand_dims(a, axis=3).shape =', np.expand_dims(a, axis=3).shape)
输出结果:
a = [[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
np.expand_dims(a, axis=1) = [[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]]
[[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]]
np.expand_dims(a, axis=2) = [[[[ 0 1 2 3]]
[[ 4 5 6 7]]
[[ 8 9 10 11]]]
[[[12 13 14 15]]
[[16 17 18 19]]
[[20 21 22 23]]]]
np.expand_dims(a, axis=3) = [[[[ 0]
[ 1]
[ 2]
[ 3]]
[[ 4]
[ 5]
[ 6]
[ 7]]
[[ 8]
[ 9]
[10]
[11]]]
[[[12]
[13]
[14]
[15]]
[[16]
[17]
[18]
[19]]
[[20]
[21]
[22]
[23]]]]
a.shape = (2, 3, 4)
np.expand_dims(a, axis=1).shape = (2, 1, 3, 4)
np.expand_dims(a, axis=2).shape = (2, 3, 1, 4)
np.expand_dims(a, axis=3).shape = (2, 3, 4, 1)
np.squeeze()
squeeze(a, axis=None)
参数:
- a : array_like
Input data. - axis : None or int or tuple of ints, optional
该函数的作用是:删除输入数组a中维度为1的维度,并返回新的数组,新数组的元素与原数组a完全相同。(Remove single-dimensional entries from the shape of an array.)
>>> a = np.array([[[0], [1], [2]]])
>>> a.shape
(1, 3, 1)
# 未指定axis,则删除所有维度为1的维度
>>> np.squeeze(a)
[0, 1, 2]
>>> np.squeeze(a).shape
(3,)
# 指定axis=0,则删除该维度
>>> np.squeeze(a, axis=0)
[[0]
[1]
[2]]
>>> np.squeeze(a, axis=0).shape
(3, 1)
# 指定axis=2,则删除该维度
>>> np.squeeze(a, axis=2)
[[0 1 2]]
>>> np.squeeze(a, axis=2).shape
(3, 1)
# 同时指定axis=0和axis=2,则删除这两个维度
>>> np.squeeze(a, axis=(0, 2))
[0 1 2]
>>> np.squeeze(a, axis=(0, 2)).shape
(3,)
# 对于指定的axis,其维度必定为1,否则会报错
>>> np.squeeze(a, axis=1).shape
Traceback (most recent call last):
...
ValueError: cannot select an axis to squeeze out which has size not equal to one
更多推荐
已为社区贡献5条内容
所有评论(0)