@tsuchm
Created January 5, 2026 01:51
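from sentence_transformers.cross_encoder import CrossEncoderTrainer


# Wrapper class that re-creates the loss function every time a model is
# re-generated (via model_init) during hyper-parameter search.
class CustomCrossEncoderTrainer(CrossEncoderTrainer):
    def __init__(self, *, loss_init=None, **kwargs):
        # Store the loss factory before the parent constructor runs, because
        # the base Trainer may already call call_model_init() from __init__.
        self.__loss_init = loss_init
        super().__init__(**kwargs)

    def call_model_init(self, trial=None):
        # Build a fresh model for this trial, then bind a new loss to it.
        model = super().call_model_init(trial=trial)
        if self.__loss_init is not None:
            self.loss = self.__loss_init(model)
        return model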
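The effect of this override can be checked without downloading a model. The sketch below is a hypothetical, library-free stand-in: `ToyTrainer`, `ToyModel`, `ToyLoss` and `run_trials` are invented for illustration and only mimic the one behavior that matters here, namely that `transformers.Trainer` builds a fresh model through `call_model_init(trial)` for each hyper-parameter trial. It suggests why a loss created once stays bound to the first model, while the wrapper pattern rebinds it on every trial.

```python
# Hypothetical stand-ins: nothing below comes from sentence-transformers.


class ToyModel:
    """Stand-in for a CrossEncoder; stores one sampled hyper-parameter."""

    def __init__(self, hidden_size):
        self.hidden_size = hidden_size


class ToyLoss:
    """Stand-in for a loss that keeps a reference to the model it was built with."""

    def __init__(self, model):
        self.model = model


class ToyTrainer:
    """Heavily simplified stand-in for transformers.Trainer (assumption)."""

    def __init__(self, *, model_init, loss=None):
        self.model_init = model_init
        self.loss = loss
        self.model = self.call_model_init()

    def call_model_init(self, trial=None):
        return self.model_init(trial)

    def run_trials(self, trials):
        # A fresh model per trial, as during hyper-parameter search.
        seen = []
        for trial in trials:
            self.model = self.call_model_init(trial)
            seen.append((self.model, self.loss))
        return seen


class CustomToyTrainer(ToyTrainer):
    """Same override pattern as CustomCrossEncoderTrainer."""

    def __init__(self, *, loss_init=None, **kwargs):
        self.__loss_init = loss_init
        super().__init__(**kwargs)

    def call_model_init(self, trial=None):
        model = super().call_model_init(trial=trial)
        if self.__loss_init is not None:
            self.loss = self.__loss_init(model)
        return model


def model_init(trial):
    # In this toy, a trial is a dict of sampled values; None at construction.
    return ToyModel(8 if trial is None else trial["hidden_size"])


trials = [{"hidden_size": 16}, {"hidden_size": 32}]

# Loss built once: it keeps pointing at the initial model.
stale = ToyTrainer(model_init=model_init)
stale.loss = ToyLoss(stale.model)
for model, loss in stale.run_trials(trials):
    print("stale", model.hidden_size, loss.model is model)

# Loss re-created per trial: it always points at the current model.
fresh = CustomToyTrainer(model_init=model_init, loss_init=ToyLoss)
for model, loss in fresh.run_trials(trials):
    print("fresh", model.hidden_size, loss.model is model)
```

With the real classes, `loss_init` would be a callable taking the new model, for example a loss class whose constructor accepts the model as its first argument.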