SAMを試す (論文メモはこちら - nagataka/Read-a-Paper#55) SAM3を使うにはHugging Faceのレポジトリ上からアクセスリクエストが必要そうなので、一旦SAM2で実験。
使用可能なチェックポイントはgithub レポジトリに一覧がある。今回は tiny を試すことにする。
ただ、サンプルをそのまま試しても以下のようなエラーが出て動かない。おそらく、gitからクローンしてきた sam2 ディレクトリで作業する分には問題ないのだろうけど、今回自分は別のプロジェクトディレクトリから sam2 を読み込んで使っているので、Hydraの設定をいじらないといけないようだ。
調べるのが面倒だったので、こちらの issue を参考にして問題を解決した。
from hydra.core.global_hydra import GlobalHydra
from hydra import initialize, compose
# to use ones own paths clear hydra
GlobalHydra.instance().clear()
checkpoint = "sam2/checkpoints/sam2.1_hiera_tiny.pt"
model_cfg = "sam2.1_hiera_t.yaml"
#predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
image = Image.open('resources/img/truck.jpg')
image = np.array(image.convert("RGB"))
# https://github.com/facebookresearch/sam2/issues/701
with initialize(version_base=None, config_path='../sam2/configs/'):
sam2_model = build_sam2(model_cfg, checkpoint, apply_postprocessing=False, device='cpu')
predictor = SAM2ImagePredictor(sam2_model)トレーニング・ファインチューニングしたければできる仕組みが提供されているのもありがたい - https://github.com/facebookresearch/sam2/blob/main/training/README.md
AlphaEarth Foundationsを理解するために勉強を始めた。基盤モデルのトレーニング自体は経験があるので、衛生画像やその他センサーデータに慣れることや、そもそもの問題やモチベーションを深く理解することが自分にはまず大切かな。
TODO: Landsat (NASA) と Sentinel-2 (EU) のデータをGoogle Earth Engineで見てみる。
CLIP (というかSigLIP2) を自前のデータでファインチューニングしたくて方法を確認している。
open_clip を使う場合は https://github.com/mlfoundations/open_clip/tree/main/src/open_clip_train
data.py をみると CsvDataset というクラスがあるのでこれを使うのが一番簡単そう?
とりあえず pip install 'open_clip_torch[training]' でインストール。
シングルプロセスでトレーニングを行うサンプルコマンドは
python -m open_clip_train.main \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
--train-data="/path/to/train_data.csv" \
--val-data="/path/to/validation_data.csv" \
--csv-img-key filepath \
--csv-caption-key title \
--imagenet-val=/path/to/imagenet/root/val/ \
--warmup 10000 \
--batch-size=128 \
--lr=1e-3 \
--wd=0.1 \
--epochs=30 \
--workers=8 \
--model ViT-B-16-SigLIP2-256 \
--pretrained webli
パラメータは params.py に定義されているので、必要に応じて追加すれば良い。また、open_clip で利用可能なモデルアーキテクチャと pre-training に用いられたデータの情報が見たければ open_clip.list_pretrained() で簡単に確認できる。
>>> import open_clip
>>> print(open_clip.list_pretrained())
[('RN50', 'openai'), ('RN50', 'yfcc15m'), ('RN50', 'cc12m'), ('RN101', 'openai'), ('RN101', 'yfcc15m'), ('RN50x4', 'openai'), ('RN50x16', 'openai'), ('RN50x64', 'openai'), ('ViT-B-32', 'openai'), ('ViT-B-32', 'laion400m_e31'), ('ViT-B-32', 'laion400m_e32'), ('ViT-B-32', 'laion2b_e16'), ('ViT-B-32', 'laion2b_s34b_b79k'), ('ViT-B-32', 'datacomp_xl_s13b_b90k'), ('ViT-B-32', 'datacomp_m_s128m_b4k'), ('ViT-B-32', 'commonpool_m_clip_s128m_b4k'), ('ViT-B-32', 'commonpool_m_laion_s128m_b4k'), ('ViT-B-32', 'commonpool_m_image_s128m_b4k'), ('ViT-B-32', 'commonpool_m_text_s128m_b4k'), ('ViT-B-32', 'commonpool_m_basic_s128m_b4k'), ('ViT-B-32', 'commonpool_m_s128m_b4k'), ('ViT-B-32', 'datacomp_s_s13m_b4k'), ('ViT-B-32', 'commonpool_s_clip_s13m_b4k'), ('ViT-B-32', 'commonpool_s_laion_s13m_b4k'), ('ViT-B-32', 'commonpool_s_image_s13m_b4k'), ('ViT-B-32', 'commonpool_s_text_s13m_b4k'), ('ViT-B-32', 'commonpool_s_basic_s13m_b4k'), ('ViT-B-32', 'commonpool_s_s13m_b4k'), ('ViT-B-32', 'metaclip_400m'), ('ViT-B-32', 'metaclip_fullcc'), ('ViT-B-32-256', 'datacomp_s34b_b86k'), ('ViT-B-16', 'openai'), ('ViT-B-16', 'laion400m_e31'), ('ViT-B-16', 'laion400m_e32'), ('ViT-B-16', 'laion2b_s34b_b88k'), ('ViT-B-16', 'datacomp_xl_s13b_b90k'), ('ViT-B-16', 'datacomp_l_s1b_b8k'), ('ViT-B-16', 'commonpool_l_clip_s1b_b8k'), ('ViT-B-16', 'commonpool_l_laion_s1b_b8k'), ('ViT-B-16', 'commonpool_l_image_s1b_b8k'), ('ViT-B-16', 'commonpool_l_text_s1b_b8k'), ('ViT-B-16', 'commonpool_l_basic_s1b_b8k'), ('ViT-B-16', 'commonpool_l_s1b_b8k'), ('ViT-B-16', 'dfn2b'), ('ViT-B-16', 'metaclip_400m'), ('ViT-B-16', 'metaclip_fullcc'), ('ViT-B-16-plus-240', 'laion400m_e31'), ('ViT-B-16-plus-240', 'laion400m_e32'), ('ViT-L-14', 'openai'), ('ViT-L-14', 'laion400m_e31'), ('ViT-L-14', 'laion400m_e32'), ('ViT-L-14', 'laion2b_s32b_b82k'), ('ViT-L-14', 'datacomp_xl_s13b_b90k'), ('ViT-L-14', 'commonpool_xl_clip_s13b_b90k'), ('ViT-L-14', 'commonpool_xl_laion_s13b_b90k'), ('ViT-L-14', 'commonpool_xl_s13b_b90k'), ('ViT-L-14', 'metaclip_400m'), ('ViT-L-14', 'metaclip_fullcc'), ('ViT-L-14', 'dfn2b'), ('ViT-L-14', 'dfn2b_s39b'), ('ViT-L-14-336', 'openai'), ('ViT-H-14', 'laion2b_s32b_b79k'), ('ViT-H-14', 'metaclip_fullcc'), ('ViT-H-14', 'metaclip_altogether'), ('ViT-H-14', 'dfn5b'), ('ViT-H-14-378', 'dfn5b'), ('ViT-g-14', 'laion2b_s12b_b42k'), ('ViT-g-14', 'laion2b_s34b_b88k'), ('ViT-bigG-14', 'laion2b_s39b_b160k'), ('ViT-bigG-14', 'metaclip_fullcc'), ('roberta-ViT-B-32', 'laion2b_s12b_b32k'), ('xlm-roberta-base-ViT-B-32', 'laion5b_s13b_b90k'), ('xlm-roberta-large-ViT-H-14', 'frozen_laion5b_s13b_b90k'), ('convnext_base', 'laion400m_s13b_b51k'), ('convnext_base_w', 'laion2b_s13b_b82k'), ('convnext_base_w', 'laion2b_s13b_b82k_augreg'), ('convnext_base_w', 'laion_aesthetic_s13b_b82k'), ('convnext_base_w_320', 'laion_aesthetic_s13b_b82k'), ('convnext_base_w_320', 'laion_aesthetic_s13b_b82k_augreg'), ('convnext_large_d', 'laion2b_s26b_b102k_augreg'), ('convnext_large_d_320', 'laion2b_s29b_b131k_ft'), ('convnext_large_d_320', 'laion2b_s29b_b131k_ft_soup'), ('convnext_xxlarge', 'laion2b_s34b_b82k_augreg'), ('convnext_xxlarge', 'laion2b_s34b_b82k_augreg_rewind'), ('convnext_xxlarge', 'laion2b_s34b_b82k_augreg_soup'), ('coca_ViT-B-32', 'laion2b_s13b_b90k'), ('coca_ViT-B-32', 'mscoco_finetuned_laion2b_s13b_b90k'), ('coca_ViT-L-14', 'laion2b_s13b_b90k'), ('coca_ViT-L-14', 'mscoco_finetuned_laion2b_s13b_b90k'), ('EVA01-g-14', 'laion400m_s11b_b41k'), ('EVA01-g-14-plus', 'merged2b_s11b_b114k'), ('EVA02-B-16', 'merged2b_s8b_b131k'), ('EVA02-L-14', 'merged2b_s4b_b131k'), ('EVA02-L-14-336', 'merged2b_s6b_b61k'), ('EVA02-E-14', 'laion2b_s4b_b115k'), ('EVA02-E-14-plus', 'laion2b_s9b_b144k'), ('ViT-B-16-SigLIP', 'webli'), ('ViT-B-16-SigLIP-256', 'webli'), ('ViT-B-16-SigLIP-i18n-256', 'webli'), ('ViT-B-16-SigLIP-384', 'webli'), ('ViT-B-16-SigLIP-512', 'webli'), ('ViT-L-16-SigLIP-256', 'webli'), ('ViT-L-16-SigLIP-384', 'webli'), ('ViT-SO400M-14-SigLIP', 'webli'), ('ViT-SO400M-16-SigLIP-i18n-256', 'webli'), ('ViT-SO400M-14-SigLIP-378', 'webli'), ('ViT-SO400M-14-SigLIP-384', 'webli'), ('ViT-B-32-SigLIP2-256', 'webli'), ('ViT-B-16-SigLIP2', 'webli'), ('ViT-B-16-SigLIP2-256', 'webli'), ('ViT-B-16-SigLIP2-384', 'webli'), ('ViT-B-16-SigLIP2-512', 'webli'), ('ViT-L-16-SigLIP2-256', 'webli'), ('ViT-L-16-SigLIP2-384', 'webli'), ('ViT-L-16-SigLIP2-512', 'webli'), ('ViT-SO400M-14-SigLIP2', 'webli'), ('ViT-SO400M-14-SigLIP2-378', 'webli'), ('ViT-SO400M-16-SigLIP2-256', 'webli'), ('ViT-SO400M-16-SigLIP2-384', 'webli'), ('ViT-SO400M-16-SigLIP2-512', 'webli'), ('ViT-gopt-16-SigLIP2-256', 'webli'), ('ViT-gopt-16-SigLIP2-384', 'webli'), ('ViT-L-14-CLIPA', 'datacomp1b'), ('ViT-L-14-CLIPA-336', 'datacomp1b'), ('ViT-H-14-CLIPA', 'datacomp1b'), ('ViT-H-14-CLIPA-336', 'laion2b'), ('ViT-H-14-CLIPA-336', 'datacomp1b'), ('ViT-bigG-14-CLIPA', 'datacomp1b'), ('ViT-bigG-14-CLIPA-336', 'datacomp1b'), ('nllb-clip-base', 'v1'), ('nllb-clip-large', 'v1'), ('nllb-clip-base-siglip', 'v1'), ('nllb-clip-base-siglip', 'mrl'), ('nllb-clip-large-siglip', 'v1'), ('nllb-clip-large-siglip', 'mrl'), ('MobileCLIP-S1', 'datacompdr'), ('MobileCLIP-S2', 'datacompdr'), ('MobileCLIP-B', 'datacompdr'), ('MobileCLIP-B', 'datacompdr_lt'), ('MobileCLIP2-B', 'dfndr2b'), ('MobileCLIP2-S0', 'dfndr2b'), ('MobileCLIP2-S2', 'dfndr2b'), ('MobileCLIP2-S3', 'dfndr2b'), ('MobileCLIP2-S4', 'dfndr2b'), ('MobileCLIP2-L-14', 'dfndr2b'), ('ViTamin-S', 'datacomp1b'), ('ViTamin-S-LTT', 'datacomp1b'), ('ViTamin-B', 'datacomp1b'), ('ViTamin-B-LTT', 'datacomp1b'), ('ViTamin-L', 'datacomp1b'), ('ViTamin-L-256', 'datacomp1b'), ('ViTamin-L-336', 'datacomp1b'), ('ViTamin-L-384', 'datacomp1b'), ('ViTamin-L2', 'datacomp1b'), ('ViTamin-L2-256', 'datacomp1b'), ('ViTamin-L2-336', 'datacomp1b'), ('ViTamin-L2-384', 'datacomp1b'), ('ViTamin-XL-256', 'datacomp1b'), ('ViTamin-XL-336', 'datacomp1b'), ('ViTamin-XL-384', 'datacomp1b'), ('PE-Core-T-16-384', 'meta'), ('PE-Core-S-16-384', 'meta'), ('PE-Core-B-16', 'meta'), ('PE-Core-L-14-336', 'meta'), ('PE-Core-bigG-14-448', 'meta'), ('ViT-H-14-worldwide', 'metaclip2_worldwide'), ('ViT-H-14-worldwide-378', 'metaclip2_worldwide'), ('ViT-bigG-14-worldwide', 'metaclip2_worldwide'), ('ViT-bigG-14-worldwide-378', 'metaclip2_worldwide'), ('RN50-quickgelu', 'openai'), ('RN50-quickgelu', 'yfcc15m'), ('RN50-quickgelu', 'cc12m'), ('RN101-quickgelu', 'openai'), ('RN101-quickgelu', 'yfcc15m'), ('RN50x4-quickgelu', 'openai'), ('RN50x16-quickgelu', 'openai'), ('RN50x64-quickgelu', 'openai'), ('ViT-B-32-quickgelu', 'openai'), ('ViT-B-32-quickgelu', 'laion400m_e31'), ('ViT-B-32-quickgelu', 'laion400m_e32'), ('ViT-B-32-quickgelu', 'metaclip_400m'), ('ViT-B-32-quickgelu', 'metaclip_fullcc'), ('ViT-B-16-quickgelu', 'openai'), ('ViT-B-16-quickgelu', 'dfn2b'), ('ViT-B-16-quickgelu', 'metaclip_400m'), ('ViT-B-16-quickgelu', 'metaclip_fullcc'), ('ViT-L-14-quickgelu', 'openai'), ('ViT-L-14-quickgelu', 'metaclip_400m'), ('ViT-L-14-quickgelu', 'metaclip_fullcc'), ('ViT-L-14-quickgelu', 'dfn2b'), ('ViT-L-14-336-quickgelu', 'openai'), ('ViT-H-14-quickgelu', 'metaclip_fullcc'), ('ViT-H-14-quickgelu', 'dfn5b'), ('ViT-H-14-378-quickgelu', 'dfn5b'), ('ViT-bigG-14-quickgelu', 'metaclip_fullcc'), ('ViT-H-14-worldwide-quickgelu', 'metaclip2_worldwide')]
ファインチューニングの結果出力された重みを試すには、シンプルに以下で読み込める。
MODEL = 'ViT-B-16-SigLIP2-256'
model_path = <PATH TO A CHECKPOINT>
model, _, preprocess = open_clip.create_model_and_transforms(MODEL, pretrained=model_path)
HuggingFace Trainer でやる場合は Trainer クラスをオーバーライドして contrastive loss を計算するカスタムクラスを作る感じかな。
書いた - https://gist.github.com/nagataka/efde63affcbbdf6f0fdd7b5605d423e9