Read rope_theta for DBRX model (thanks turboderp)

This commit is contained in:
oobabooga 2024-04-01 20:25:31 -07:00
parent db5f6cd1d8
commit 9ab7365b56

View File

@ -84,6 +84,8 @@ def get_model_metadata(model):
if 'rope_theta' in metadata: if 'rope_theta' in metadata:
model_settings['rope_freq_base'] = metadata['rope_theta'] model_settings['rope_freq_base'] = metadata['rope_theta']
elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']:
model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta']
if 'rope_scaling' in metadata and type(metadata['rope_scaling']) is dict and all(key in metadata['rope_scaling'] for key in ('type', 'factor')): if 'rope_scaling' in metadata and type(metadata['rope_scaling']) is dict and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
if metadata['rope_scaling']['type'] == 'linear': if metadata['rope_scaling']['type'] == 'linear':