diff --git a/api-example-stream.py b/api-example-stream.py index 0d93b4b6..b7846ab4 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -5,77 +5,77 @@ import json import asyncio def random_hash(): - letters = string.ascii_lowercase + string.digits - return ''.join(random.choice(letters) for i in range(9)) + letters = string.ascii_lowercase + string.digits + return ''.join(random.choice(letters) for i in range(9)) async def run(context): - server = "127.0.0.1" - params = { - 'max_new_tokens': 200, - 'do_sample': True, - 'temperature': 0.5, - 'top_p': 0.9, - 'typical_p': 1, - 'repetition_penalty': 1.05, - 'top_k': 0, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'length_penalty': 1, - 'early_stopping': False, - } - session = random_hash() + server = "127.0.0.1" + params = { + 'max_new_tokens': 200, + 'do_sample': True, + 'temperature': 0.5, + 'top_p': 0.9, + 'typical_p': 1, + 'repetition_penalty': 1.05, + 'top_k': 0, + 'min_length': 0, + 'no_repeat_ngram_size': 0, + 'num_beams': 1, + 'penalty_alpha': 0, + 'length_penalty': 1, + 'early_stopping': False, + } + session = random_hash() - async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: - while content := json.loads(await websocket.recv()): - #Python3.10 syntax, replace with if elif on older - match content["msg"]: - case "send_hash": - await websocket.send(json.dumps({ - "session_hash": session, - "fn_index": 7 - })) - case "estimation": - pass - case "send_data": - await websocket.send(json.dumps({ - "session_hash": session, - "fn_index": 7, - "data": [ - context, - params['max_new_tokens'], - params['do_sample'], - params['temperature'], - params['top_p'], - params['typical_p'], - params['repetition_penalty'], - params['top_k'], - params['min_length'], - params['no_repeat_ngram_size'], - params['num_beams'], - params['penalty_alpha'], - params['length_penalty'], - params['early_stopping'], - ] - })) - case "process_starts": - pass - case "process_generating" | "process_completed": - yield content["output"]["data"][0] - # You can search for your desired end indicator and - # stop generation by closing the websocket here - if (content["msg"] == "process_completed"): - break + async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: + while content := json.loads(await websocket.recv()): + #Python3.10 syntax, replace with if elif on older + match content["msg"]: + case "send_hash": + await websocket.send(json.dumps({ + "session_hash": session, + "fn_index": 7 + })) + case "estimation": + pass + case "send_data": + await websocket.send(json.dumps({ + "session_hash": session, + "fn_index": 7, + "data": [ + context, + params['max_new_tokens'], + params['do_sample'], + params['temperature'], + params['top_p'], + params['typical_p'], + params['repetition_penalty'], + params['top_k'], + params['min_length'], + params['no_repeat_ngram_size'], + params['num_beams'], + params['penalty_alpha'], + params['length_penalty'], + params['early_stopping'], + ] + })) + case "process_starts": + pass + case "process_generating" | "process_completed": + yield content["output"]["data"][0] + # You can search for your desired end indicator and + # stop generation by closing the websocket here + if (content["msg"] == "process_completed"): + break prompt = "What I would like to say is the following: " async def get_result(): - async for response in run(prompt): - # Print intermediate steps - print(response) + async for response in run(prompt): + # Print intermediate steps + print(response) - # Print final result - print(response) + # Print final result + print(response) -asyncio.run(get_result()) \ No newline at end of file +asyncio.run(get_result())