2023-09-16 05:11:16 +02:00
|
|
|
from modules.text_generation import decode, encode
|
|
|
|
|
2023-07-12 20:33:25 +02:00
|
|
|
|
2023-07-11 23:50:08 +02:00
|
|
|
def token_count(prompt):
    """Return the token length of *prompt* as ``{'length': N}``.

    The prompt is run through the model tokenizer (``encode``) and only
    the count of the resulting token ids is reported.
    """
    return {
        'length': len(encode(prompt)[0])
    }
|
|
|
|
|
|
|
|
|
2023-11-08 04:05:36 +01:00
|
|
|
def token_encode(input):
    """Tokenize *input* and return ``{'tokens': [...], 'length': N}``.

    Tensor/ndarray outputs from the tokenizer are converted to plain
    Python lists so the result is JSON-serializable. The class-name
    comparison (instead of ``isinstance``) deliberately avoids needing
    torch/numpy imports in this module.
    """
    ids = encode(input)[0]
    if type(ids).__name__ in ('Tensor', 'ndarray'):
        ids = ids.tolist()

    return {
        'tokens': ids,
        'length': len(ids),
    }
|
|
|
|
|
|
|
|
|
2023-11-08 04:05:36 +01:00
|
|
|
def token_decode(tokens):
    """Decode a sequence of token ids back into text.

    Returns the decoded string wrapped as ``{'text': ...}``.
    """
    return {
        'text': decode(tokens)
    }
|