Merge pull request #4777 from oobabooga/dev

Merge dev branch
This commit is contained in:
oobabooga 2023-12-01 00:00:17 -03:00 committed by GitHub
commit 96df4f10b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 2 deletions

View File

@ -87,8 +87,8 @@
" !pip uninstall -y flash_attn\n", " !pip uninstall -y flash_attn\n",
"\n", "\n",
"# Parameters\n", "# Parameters\n",
"model_url = \"https://huggingface.co/turboderp/Mistral-7B-instruct-exl2\" #@param {type:\"string\"}\n", "model_url = \"https://huggingface.co/TheBloke/MythoMax-L2-13B-GPTQ\" #@param {type:\"string\"}\n",
"branch = \"4.0bpw\" #@param {type:\"string\"}\n", "branch = \"gptq-4bit-32g-actorder_True\" #@param {type:\"string\"}\n",
"command_line_flags = \"--n-gpu-layers 128 --load-in-4bit --use_double_quant\" #@param {type:\"string\"}\n", "command_line_flags = \"--n-gpu-layers 128 --load-in-4bit --use_double_quant\" #@param {type:\"string\"}\n",
"api = False #@param {type:\"boolean\"}\n", "api = False #@param {type:\"boolean\"}\n",
"\n", "\n",

View File

@ -105,6 +105,7 @@ class LlamaCppModel:
return self.model.detokenize(ids).decode('utf-8') return self.model.detokenize(ids).decode('utf-8')
def get_logits(self, tokens): def get_logits(self, tokens):
self.model.reset()
self.model.eval(tokens) self.model.eval(tokens)
logits = self.model._scores logits = self.model._scores
logits = np.expand_dims(logits, 0) # batch dim is expected logits = np.expand_dims(logits, 0) # batch dim is expected

View File

@ -102,6 +102,7 @@ def load_model(model_name, loader=None):
elif loader in ['llama.cpp', 'llamacpp_HF', 'ctransformers']: elif loader in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
shared.settings['truncation_length'] = shared.args.n_ctx shared.settings['truncation_length'] = shared.args.n_ctx
logger.info(f"LOADER: {loader}")
logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}") logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}")
logger.info(f"INSTRUCTION TEMPLATE: {shared.settings['instruction_template']}") logger.info(f"INSTRUCTION TEMPLATE: {shared.settings['instruction_template']}")
logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.") logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")