Skip to content

Instantly share code, notes, and snippets.

@trueroad
Last active January 20, 2026 05:02
Show Gist options
  • Select an option

  • Save trueroad/a704a7ab54a851c0ddfbb359b2e7f87a to your computer and use it in GitHub Desktop.

Select an option

Save trueroad/a704a7ab54a851c0ddfbb359b2e7f87a to your computer and use it in GitHub Desktop.
Calc SMF (Standard MIDI File) Distance.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Calc SMF (Standard MIDI File) Distance.
https://gist.github.com/trueroad/a704a7ab54a851c0ddfbb359b2e7f87a
Copyright (C) 2026 Masamichi Hosoda.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
SUCH DAMAGE.
"""
import sys
from typing import Any, Final, Optional
# https://gist.github.com/trueroad/52b7c4c98eec5fdf0ff3f62d64ec17bd
import smf_parse
# https://gist.github.com/trueroad/97477dab8beca099afeb4af5199634e2
import smf_diff
VERSION: Final[str] = '20260120.01'
class Distance:
"""
SMF距離(非類似度)計算クラス.
距離計算方式は下記論文に基づく
細田真道, 最知庸, 小林丈之, 笹生恵理, 山内竣平, 野口啓之, 阪内澄宇:
「ピアノ宿題練習のためのAI採点方式」, FIT2022, No. CE-007, 2022.
"""
def __init__(self,
verbose: int = 0) -> None:
"""
__init__.
Args:
verbose (int): Verbose レベル
"""
# Verbose レベル
self.verbose: Final[int] = verbose
# 係数のデフォルト値
# 詳細は論文を参照のこと
self.a_missing: float = 1.0
self.a_extra: float = 1.0
self.a_previous_r: float = 0.25
self.a_previous_m: float = 0.25
self.a_duration_r: float = 0.25
self.a_duration_m: float = 0.25
self.a_velocity_r: float = 0.01
self.a_velocity_m: float = 0.01
def __print_v(self,
*args: Any, level: int = 1, **kwargs: Any) -> None:
"""
Verbose レベルを考慮して print する.
Args:
*args (Any): print するもの
level (int): 出力する最低の verbose レベル
**kwargs (Any): print に渡すオプション引数
"""
if self.verbose >= level:
print(*args, **kwargs)
def calc(self,
sd: smf_diff.smf_difference,
begin: Optional[smf_parse.mbt_container] = None,
end: Optional[smf_parse.mbt_container] = None,
) -> float:
"""
差分比較クラスのインスタンスにあるSMF同士の距離(非類似度)計算をする.
Args:
sd (smf_diff.smf_difference): 差分比較クラスのインスタンス
begin (Optional[smf_parse.mbt_container]): 区間の先頭を指定する。
モデルのMBTで指定する。指定されたMBTは区間に含まれる。
NoneはSMFの最初からを意味する。
end (Optional[smf_parse.mbt_container]): 区間の最後の後を指定する。
モデルのMBTで指定する。指定されたMBTは区間に含まれない。
NoneはSMFの最後までを意味する。
Returns:
float: 距離
"""
# 計算方法
# 係数a_missing×(欠落した音符数÷モデルの音符数)
# +係数a_extra×(余計な音符数÷モデルの音符数)
# +係数a_previous_r×前の音符との時間差のRMSPE
# +係数a_previous_m×前の音符との時間差のMAPE
# +係数a_duration_r×音の長さのRMSPE
# +係数a_duration_m×音の長さのMAPE
# +係数a_velocity_r×ベロシティのRMSE
# +係数a_velocity_m×ベロシティのMAE
# (ただしマッチングが取れた音符数がゼロの場合は無条件でinfを返す)
# 性質:
# 数値が小さい方が近い
# 完全一致は0.0になり、これが最小
model_diff_len: int = len(sd.get_model_note_by_range(begin, end))
if model_diff_len == 0:
# 区間内にノートONがあるモデルの音符が一つも無い
self.__print_v(f'Error: model_diff_len = 0')
raise ValueError(f'model_diff_len = 0')
if len(sd.get_note_timing_by_range(begin, end)) == 0:
# マッチングが取れた音符数が一つも無いのでinfを返す
self.__print_v('Warning: no matching notes: '
'Distance is assumed to be infinite')
return float('inf')
distance: float \
= ((self.a_missing * len(sd.get_missing_note_by_range(begin, end))
/ model_diff_len)
+ (self.a_extra * len(sd.get_extra_note_by_range(begin, end))
/ model_diff_len)
+ self.a_previous_r * sd.calc_previous_rmspe(begin, end)
+ self.a_previous_m * sd.calc_previous_mape(begin, end)
+ self.a_duration_r * sd.calc_duration_rmspe(begin, end)
+ self.a_duration_m * sd.calc_duration_mape(begin, end)
+ self.a_velocity_r * sd.calc_velocity_rmse(begin, end)
+ self.a_velocity_m * sd.calc_velocity_mae(begin, end))
return distance
def main() -> None:
"""テスト用メイン."""
print(f'Calc SMF (Standard MIDI File) Distance {VERSION}\n\n'
'https://gist.github.com/trueroad/'
'a704a7ab54a851c0ddfbb359b2e7f87a\n\n'
'Copyright (C) 2026 Masamichi Hosoda.\n'
'All rights reserved.\n')
import argparse
parser: argparse.ArgumentParser = argparse.ArgumentParser()
# 入力ファイル
parser.add_argument('MODEL.mid', help='Model SMF.')
parser.add_argument('FOR_EVAL.mid', help='For-eval SMF.')
# 差分比較パラメータ(デフォルト値として論文の値を設定)
parser.add_argument('--max-misalignment',
help='t_misalignment: '
'Maximum time misalignment considered '
'simultaneous note on.',
type=float, default=0.05, required=False)
parser.add_argument('--filter-velocity',
help='th_velocity: '
'Filters out notes with less velocity '
'than specified from for-eval SMF.',
type=int, default=10, required=False)
parser.add_argument('--filter-duration',
help='t_duration: '
'Filters out notes with less duration '
'than specified from for-eval SMF (unit sec).',
type=float, default=0.05, required=False)
# 論文に登場しない差分比較パラメータ(デフォルトは論文の条件と同じ)
parser.add_argument('--filter-noteno-margin',
help='Filters out notes outside of '
'the model SMF note range plus or minus '
'this setting margin range.',
type=int, default=128, required=False)
parser.add_argument('--octave-reduction',
help='Enable octave reduction.',
action='store_true')
parser.add_argument('--strict-diff',
help='Enable strict diff.',
action='store_true')
args: argparse.Namespace = parser.parse_args()
vargs: dict[str, Any] = vars(args)
model_filename: str = vargs['MODEL.mid']
foreval_filename: str = vargs['FOR_EVAL.mid']
max_misalignment: float = vargs['max_misalignment']
filter_velocity: int = vargs['filter_velocity']
filter_duration: float = vargs['filter_duration']
filter_noteno_margin: int = vargs['filter_noteno_margin']
b_octave_reduction: bool = vargs['octave_reduction']
b_strict_diff: bool = vargs['strict_diff']
print(f'Model SMF : {model_filename}\n'
f'For-eval SMF : {foreval_filename}\n'
f'max-misalignment : {max_misalignment}\n'
f'filter-velocity : {filter_velocity}\n'
f'filter-duration : {filter_duration}\n'
f'filter-noteno-margin: {filter_noteno_margin}\n'
f'Octave reduction : {b_octave_reduction}\n'
f'Strict diff : {b_strict_diff}\n')
# 比較クラス
sd: smf_diff.smf_difference \
= smf_diff.smf_difference(verbose=1,
max_misalignment=max_misalignment,
filter_velocity=filter_velocity,
filter_duration=filter_duration,
filter_noteno_margin=filter_noteno_margin,
b_octave_reduction=b_octave_reduction,
b_strict_diff=b_strict_diff)
# モデルをロード
if not sd.load_model(model_filename):
return
# 評価対象をロード
if not sd.load_foreval(foreval_filename):
return
print('\nDiff notes')
# 差分をとる
sd.diff()
# タイミング系の集計
sd.calc_note_timing()
# 距離計算
dist: Distance = Distance()
d: float = dist.calc(sd)
print(f'\nDistance: {d}')
if __name__ == '__main__':
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment