Make sampler priority high if unspecified

This commit is contained in:
oobabooga 2024-09-29 20:45:27 -07:00
parent 01362681f2
commit bbdeed3cf4

View File

@ -602,11 +602,10 @@ def get_logits_processor_patch(self, **kwargs):
def custom_sort_key(obj): def custom_sort_key(obj):
class_name = obj.__class__.__name__ class_name = obj.__class__.__name__
# Return a large value if class name is not mapped or if the mapped nickname is not in priority # Return -1 if class_name is not mapped
if class_name not in class_name_to_nickname or class_name_to_nickname[class_name] not in sampler_priority: if class_name not in class_name_to_nickname or class_name_to_nickname[class_name] not in sampler_priority:
return float('inf') return -1
# Return the index of the nickname in the priority list for sorting
return sampler_priority.index(class_name_to_nickname[class_name]) return sampler_priority.index(class_name_to_nickname[class_name])
# Sort the list using the custom key function # Sort the list using the custom key function