【テンソル操作のキホン】reshape・squeeze・unsqueezeを完全マスター!📦➡️📐
機械学習・深層学習の実装において、避けて通れないのが**「テンソルの形状(shape)変換」**です。
この操作がうまく使えるようになると、モデルへの入力整形やデータ処理が一気にスマートに!
この記事では、代表的な3つの操作――
-
reshape
-
squeeze
-
unsqueeze
を中心に、図解的なイメージとPythonコードでわかりやすく解説します😊✨
🔄 reshape:形を自由に変える!
import numpy as np
a = np.array([[1, 2, 3],
[4, 5, 6]]) # shape: (2, 3)
b = a.reshape(3, 2) # shape: (3, 2)
出力:
[[1 2]
[3 4]
[5 6]]
📌 ポイント:
-
要素数(個数)が変わらない限り、自由に形を変えられる
-
-1
を使えば「自動計算」もできる!
a.reshape(-1, 1) # 自動的に列1の形にしてくれる
これは画像や時系列データのバッチ整形などに超便利です💡
✂️ squeeze:余分な次元を取り除く!
a = np.array([[[1], [2], [3]]]) # shape: (1, 3, 1)
b = np.squeeze(a) # shape: (3,)
📌 squeezeの役割:
-
次元が1の箇所(例:
(1, 3, 1)
の「1」部分)を削除してくれます
これは、余計な次元が残っていて演算できないときの救世主🛠️
例:
a.shape → (1, 28, 28, 1) # 画像処理などによくある形
np.squeeze(a).shape → (28, 28)
➕ unsqueeze:次元を追加する!
a = np.array([1, 2, 3]) # shape: (3,)
b = np.expand_dims(a, axis=0) # shape: (1, 3)
c = np.expand_dims(a, axis=1) # shape: (3, 1)
📌 unsqueeze(=expand_dims)の役割:
-
特定の軸に新しい次元(= 1)を追加する!
これはバッチ処理やCNNの入力整形時によく使われるんです。
🧪 PyTorchでの形状変換
import torch
x = torch.tensor([1, 2, 3])
# reshape
x_reshape = x.reshape(3, 1)
# unsqueeze
x_unsq = x.unsqueeze(0) # → shape: (1, 3)
# squeeze
x_sq = x_unsq.squeeze() # → shape: (3,)
📝 PyTorchでは .view()
という関数も reshape
とほぼ同じ意味で使えますが、reshape
の方が安全で推奨されます。
💥 よくある使いどころ
📌 画像入力のチャンネル追加
# (28, 28) → (1, 28, 28)
image = torch.tensor(image_array).unsqueeze(0)
📌 バッチ整形
# (28, 28) → (1, 1, 28, 28)
image = image.unsqueeze(0).unsqueeze(0)
📌 モデルの出力から余計な次元削除
output = output.squeeze() # → shapeを (1,) → () に
⚠️ 注意点
-
reshapeやviewは「要素数を変えない」のが前提!
-
squeezeは次元が1でないと消せない
-
unsqueeze後の処理に注意! → 入力shapeが変わるとモデルがエラーになることも
✅ まとめ
-
reshape
:形を変える(要素数はそのまま) -
squeeze
:不要な次元を削除(次元=1) -
unsqueeze
:新しい次元を追加(次元=1)
テンソルの形状変換を自在に扱えるようになれば、モデルの前処理・後処理がスムーズになり、開発効率が格段にアップ!🚀
次回は「テンソルの結合・分割(concatenate / stack / split / chunk)」について解説予定!
お楽しみに📚✨