Print the softprompt metadata when it is loaded

This commit is contained in:
oobabooga 2023-02-19 01:48:23 -03:00 committed by GitHub
parent f79805f4a4
commit 8c9dd95d55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -173,7 +173,19 @@ def load_soft_prompt(name):
else: else:
with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf: with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
zf.extract('tensor.npy') zf.extract('tensor.npy')
zf.extract('meta.json')
j = json.loads(open('meta.json', 'r').read())
print(f"\nLoading the softprompt \"{name}\".")
for field in j:
if field != 'name':
if type(j[field]) is list:
print(f"{field}: {', '.join(j[field])}")
else:
print(f"{field}: {j[field]}")
print()
tensor = np.load('tensor.npy') tensor = np.load('tensor.npy')
Path('tensor.npy').unlink()
Path('meta.json').unlink()
tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype) tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype)
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1])) tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
@ -187,6 +199,7 @@ def upload_soft_prompt(file):
zf.extract('meta.json') zf.extract('meta.json')
j = json.loads(open('meta.json', 'r').read()) j = json.loads(open('meta.json', 'r').read())
name = j['name'] name = j['name']
Path('meta.json').unlink()
with open(Path(f'softprompts/{name}.zip'), 'wb') as f: with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
f.write(file) f.write(file)