Last active
January 20, 2026 05:02
-
-
Save trueroad/a704a7ab54a851c0ddfbb359b2e7f87a to your computer and use it in GitHub Desktop.
Calc SMF (Standard MIDI File) Distance.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Calc SMF (Standard MIDI File) Distance. | |
| https://gist.github.com/trueroad/a704a7ab54a851c0ddfbb359b2e7f87a | |
| Copyright (C) 2026 Masamichi Hosoda. | |
| All rights reserved. | |
| Redistribution and use in source and binary forms, with or without | |
| modification, are permitted provided that the following conditions | |
| are met: | |
| * Redistributions of source code must retain the above copyright notice, | |
| this list of conditions and the following disclaimer. | |
| * Redistributions in binary form must reproduce the above copyright notice, | |
| this list of conditions and the following disclaimer in the documentation | |
| and/or other materials provided with the distribution. | |
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |
| ARE DISCLAIMED. | |
| IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
| FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
| DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS | |
| OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) | |
| HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |
| LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |
| OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF | |
| SUCH DAMAGE. | |
| """ | |
| import sys | |
| from typing import Any, Final, Optional | |
| # https://gist.github.com/trueroad/52b7c4c98eec5fdf0ff3f62d64ec17bd | |
| import smf_parse | |
| # https://gist.github.com/trueroad/97477dab8beca099afeb4af5199634e2 | |
| import smf_diff | |
# Version string of this script; main() prints it in the start-up banner.
# NOTE(review): the value looks like a YYYYMMDD.NN release stamp - confirm.
VERSION: Final[str] = '20260120.01'
class Distance:
    """
    Calculator for the distance (dissimilarity) between two SMFs.

    The distance formula is based on the following paper:

    細田真道, 最知庸, 小林丈之, 笹生恵理, 山内竣平, 野口啓之, 阪内澄宇:
    「ピアノ宿題練習のためのAI採点方式」, FIT2022, No. CE-007, 2022.
    """

    def __init__(self,
                 verbose: int = 0) -> None:
        """
        Store the verbosity level and set the default coefficients.

        Args:
          verbose (int): Verbose level
        """
        # Messages are printed only when their level is <= this value.
        self.verbose: Final[int] = verbose

        # Default weights of the individual distance terms.
        # See the paper cited in the class docstring for details.
        self.a_missing: float = 1.0
        self.a_extra: float = 1.0
        self.a_previous_r: float = 0.25
        self.a_previous_m: float = 0.25
        self.a_duration_r: float = 0.25
        self.a_duration_m: float = 0.25
        self.a_velocity_r: float = 0.01
        self.a_velocity_m: float = 0.01

    def __print_v(self,
                  *args: Any, level: int = 1, **kwargs: Any) -> None:
        """
        Print only when the verbose level is high enough.

        Args:
          *args (Any): Objects to print
          level (int): Minimum verbose level at which to print
          **kwargs (Any): Keyword arguments forwarded to print
        """
        if self.verbose < level:
            return
        print(*args, **kwargs)

    def calc(self,
             sd: smf_diff.smf_difference,
             begin: Optional[smf_parse.mbt_container] = None,
             end: Optional[smf_parse.mbt_container] = None,
             ) -> float:
        """
        Calculate the distance between the SMFs held by a diff instance.

        The result is the weighted sum of:
          a_missing    * (missing notes / model notes)
          a_extra      * (extra notes / model notes)
          a_previous_r * RMSPE of the time gap from the previous note
          a_previous_m * MAPE of the time gap from the previous note
          a_duration_r * RMSPE of the note duration
          a_duration_m * MAPE of the note duration
          a_velocity_r * RMSE of the velocity
          a_velocity_m * MAE of the velocity

        Smaller values mean the two SMFs are closer; per the original
        author's notes a perfect match gives 0.0, which is the minimum.

        Args:
          sd (smf_diff.smf_difference): Diff comparison instance
          begin (Optional[smf_parse.mbt_container]): Start of the range,
            given as an MBT of the model. The given MBT is included in
            the range. None means from the beginning of the SMF.
          end (Optional[smf_parse.mbt_container]): Position just after the
            end of the range, given as an MBT of the model. The given MBT
            is not included in the range. None means to the end of the SMF.

        Returns:
          float: The distance. It is inf when no note could be matched
            within the range.

        Raises:
          ValueError: If the model has no note in the range.
        """
        model_note_count: int = len(sd.get_model_note_by_range(begin, end))
        if not model_note_count:
            # The model has no note-on inside the range, so the ratios
            # below would have a zero denominator.
            self.__print_v('Error: model_diff_len = 0')
            raise ValueError('model_diff_len = 0')

        matched_count: int = len(sd.get_note_timing_by_range(begin, end))
        if not matched_count:
            # Not a single note was matched: return inf unconditionally.
            self.__print_v('Warning: no matching notes: '
                           'Distance is assumed to be infinite')
            return float('inf')

        # Note-count terms, normalised by the number of model notes.
        missing_term: float = (
            self.a_missing * len(sd.get_missing_note_by_range(begin, end))
            / model_note_count)
        extra_term: float = (
            self.a_extra * len(sd.get_extra_note_by_range(begin, end))
            / model_note_count)

        # Timing and velocity error terms supplied by the diff instance.
        previous_rmspe_term: float \
            = self.a_previous_r * sd.calc_previous_rmspe(begin, end)
        previous_mape_term: float \
            = self.a_previous_m * sd.calc_previous_mape(begin, end)
        duration_rmspe_term: float \
            = self.a_duration_r * sd.calc_duration_rmspe(begin, end)
        duration_mape_term: float \
            = self.a_duration_m * sd.calc_duration_mape(begin, end)
        velocity_rmse_term: float \
            = self.a_velocity_r * sd.calc_velocity_rmse(begin, end)
        velocity_mae_term: float \
            = self.a_velocity_m * sd.calc_velocity_mae(begin, end)

        # Added strictly left to right so the floating-point result does
        # not depend on a different summation order.
        return (missing_term
                + extra_term
                + previous_rmspe_term
                + previous_mape_term
                + duration_rmspe_term
                + duration_mape_term
                + velocity_rmse_term
                + velocity_mae_term)
def main() -> None:
    """
    Command-line entry point intended for testing.

    Prints a banner, parses the command line, loads the model SMF and the
    for-eval SMF into a diff instance, runs the diff and the timing
    aggregation, then prints the calculated distance. Returns early
    without printing a distance if either SMF fails to load.
    """
    print(f'Calc SMF (Standard MIDI File) Distance {VERSION}\n\n'
          'https://gist.github.com/trueroad/'
          'a704a7ab54a851c0ddfbb359b2e7f87a\n\n'
          'Copyright (C) 2026 Masamichi Hosoda.\n'
          'All rights reserved.\n')

    import argparse

    cli: argparse.ArgumentParser = argparse.ArgumentParser()

    # Input files
    cli.add_argument('MODEL.mid', help='Model SMF.')
    cli.add_argument('FOR_EVAL.mid', help='For-eval SMF.')

    # Diff parameters (defaults are the values used in the paper)
    cli.add_argument('--max-misalignment',
                     type=float, default=0.05,
                     help='t_misalignment: '
                     'Maximum time misalignment considered '
                     'simultaneous note on.')
    cli.add_argument('--filter-velocity',
                     type=int, default=10,
                     help='th_velocity: '
                     'Filters out notes with less velocity '
                     'than specified from for-eval SMF.')
    cli.add_argument('--filter-duration',
                     type=float, default=0.05,
                     help='t_duration: '
                     'Filters out notes with less duration '
                     'than specified from for-eval SMF (unit sec).')

    # Diff parameters that do not appear in the paper
    # (defaults reproduce the paper's conditions)
    cli.add_argument('--filter-noteno-margin',
                     type=int, default=128,
                     help='Filters out notes outside of '
                     'the model SMF note range plus or minus '
                     'this setting margin range.')
    cli.add_argument('--octave-reduction',
                     action='store_true',
                     help='Enable octave reduction.')
    cli.add_argument('--strict-diff',
                     action='store_true',
                     help='Enable strict diff.')

    ns: argparse.Namespace = cli.parse_args()

    # The positional names contain a dot, so they cannot be read with
    # plain attribute syntax.
    model_filename: str = getattr(ns, 'MODEL.mid')
    foreval_filename: str = getattr(ns, 'FOR_EVAL.mid')
    max_misalignment: float = ns.max_misalignment
    filter_velocity: int = ns.filter_velocity
    filter_duration: float = ns.filter_duration
    filter_noteno_margin: int = ns.filter_noteno_margin
    b_octave_reduction: bool = ns.octave_reduction
    b_strict_diff: bool = ns.strict_diff

    settings_summary: str = (
        f'Model SMF : {model_filename}\n'
        f'For-eval SMF : {foreval_filename}\n'
        f'max-misalignment : {max_misalignment}\n'
        f'filter-velocity : {filter_velocity}\n'
        f'filter-duration : {filter_duration}\n'
        f'filter-noteno-margin: {filter_noteno_margin}\n'
        f'Octave reduction : {b_octave_reduction}\n'
        f'Strict diff : {b_strict_diff}\n')
    print(settings_summary)

    # Diff comparison instance
    differ: smf_diff.smf_difference = smf_diff.smf_difference(
        verbose=1,
        max_misalignment=max_misalignment,
        filter_velocity=filter_velocity,
        filter_duration=filter_duration,
        filter_noteno_margin=filter_noteno_margin,
        b_octave_reduction=b_octave_reduction,
        b_strict_diff=b_strict_diff)

    # Load the model first, then the for-eval SMF; the short-circuit
    # `and` skips the second load when the first one fails.
    if not (differ.load_model(model_filename)
            and differ.load_foreval(foreval_filename)):
        return

    print('\nDiff notes')

    # Take the diff
    differ.diff()

    # Aggregate the timing-related values
    differ.calc_note_timing()

    # Calculate the distance
    calculator: Distance = Distance()
    result: float = calculator.calc(differ)
    print(f'\nDistance: {result}')
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment