added safe exit, bulk call, polling delay for broker
This commit is contained in:
@@ -10,13 +10,14 @@ from ._callSpec import _CallPacket, _ClientPacket
|
||||
__all__ = ["runBroker"]
|
||||
|
||||
class _Broker:
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, pollingDelay=0.5) -> None:
|
||||
self.router = fastapi.APIRouter()
|
||||
self.router.add_api_websocket_route("/reg", self.registerRunner)
|
||||
self.router.add_api_route("/cliReq", self.clientRequest, methods=["POST"])
|
||||
self.taskQueue = asyncio.Queue()
|
||||
self.runnerCount=0
|
||||
self.returnDict = {}
|
||||
self.pollingDelay = pollingDelay
|
||||
|
||||
|
||||
async def registerRunner(self, wsConnection: fastapi.WebSocket):
|
||||
@@ -28,28 +29,23 @@ class _Broker:
|
||||
self.runnerCount+=1
|
||||
while True:
|
||||
reqID, data = await self.taskQueue.get()
|
||||
# await asyncio.sleep(1)
|
||||
await wsConnection.send_bytes(pkl.dumps(data))
|
||||
retValue = await wsConnection.receive()
|
||||
# print(retValue)
|
||||
self.returnDict[reqID] = retValue["bytes"]
|
||||
# print(retValue["bytes"])
|
||||
print(f"Tasks left: {self.taskQueue.qsize()}")
|
||||
|
||||
async def clientRequest(self, data:_ClientPacket):
|
||||
# print(data)
|
||||
reqID = uuid.uuid4().hex
|
||||
callPacket = pkl.loads(base64.b64decode(data.data))
|
||||
await self.taskQueue.put((reqID, callPacket))
|
||||
# print(self.taskQueue.qsize)
|
||||
while reqID not in self.returnDict:
|
||||
await asyncio.sleep(0.5)
|
||||
await asyncio.sleep(1)
|
||||
await asyncio.sleep(self.pollingDelay)
|
||||
# await asyncio.sleep(1)
|
||||
returnValue = self.returnDict[reqID]
|
||||
return returnValue
|
||||
|
||||
def runBroker(host, port):
|
||||
br = _Broker()
|
||||
def runBroker(host, port, pollingDelay=0.1):
|
||||
br = _Broker(pollingDelay=pollingDelay)
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(br.router)
|
||||
serverConf = Config(app = app, host=host, port=port, log_level=LOG_LEVELS["warning"], ws_ping_interval=10, ws_ping_timeout=None)
|
||||
|
||||
@@ -1,18 +1,43 @@
|
||||
import time
|
||||
import requests as req
|
||||
from ._callSpec import _CallPacket
|
||||
import pickle as pkl
|
||||
import base64
|
||||
from threading import Thread
|
||||
|
||||
__all__ = ["Client"]
|
||||
|
||||
class Client:
|
||||
def __init__(self, host, port) -> None:
|
||||
self._url = f"http://{host}:{port}/cliReq"
|
||||
self.tasks = []
|
||||
|
||||
def rCall(self, function, **kwargs):
|
||||
def singleCall(self, function, **kwargs):
|
||||
callPacket = _CallPacket(procedure=function, data=kwargs)
|
||||
payload = {"data": base64.b64encode(pkl.dumps(callPacket)).decode("utf-8")}
|
||||
resp = req.post(self._url, json=payload)
|
||||
# print(resp.status_code)
|
||||
# print(resp.text)
|
||||
return pkl.loads(base64.b64decode(resp.text))
|
||||
|
||||
def addCall(self, function, **kwargs):
|
||||
self.tasks.append((function, kwargs))
|
||||
print(f"Total in Queue: {len(self.tasks)}")
|
||||
|
||||
def runAllCalls(self, callDelay=0.01):
|
||||
if len(self.tasks) == 0:
|
||||
return []
|
||||
self.returnValues = [0]*len(self.tasks)
|
||||
self.done = [0] * len(self.tasks)
|
||||
for callIDX in range(len(self.tasks)):
|
||||
t = Thread(target=self._threadWorker, args=[callIDX, self.tasks[callIDX]])
|
||||
t.start()
|
||||
time.sleep(callDelay)
|
||||
while not all(self.done):
|
||||
time.sleep(1)
|
||||
self.tasks = []
|
||||
return self.returnValues
|
||||
|
||||
def _threadWorker(self, callIDX, payload):
|
||||
print(callIDX, payload)
|
||||
ret = self.singleCall(function=payload[0], **payload[1])
|
||||
self.returnValues[callIDX] = ret
|
||||
self.done[callIDX] =1
|
||||
|
||||
@@ -1,26 +1,32 @@
|
||||
import base64
|
||||
from typing import Any, Dict
|
||||
# from fastapi import WebSocketException
|
||||
from websockets.asyncio import client as WSC
|
||||
from websockets.exceptions import WebSocketException
|
||||
import asyncio
|
||||
import pickle as pkl
|
||||
from ._callSpec import _CallPacket
|
||||
|
||||
__all__ = ["startRunner"]
|
||||
|
||||
async def _test(funcMap: Dict[str, Any], url):
|
||||
async def _send(funcMap: Dict[str, Any], url):
|
||||
counter=0
|
||||
async with WSC.connect(url, open_timeout=None, ping_interval=10, ping_timeout=None ) as w:
|
||||
id = await w.recv()
|
||||
id = int(id)
|
||||
print(f"Starting Runner, ID: {id}")
|
||||
await w.send(base64.b64encode(pkl.dumps({"methods":list(funcMap.keys())})).decode("utf-8"))
|
||||
while True:
|
||||
counter+=1
|
||||
packetBytes=await w.recv()
|
||||
callPk:_CallPacket = pkl.loads(packetBytes)
|
||||
print("-"*50 + f"\nRunning: {callPk.procedure}\nArgs: {callPk.data}\n" + "-"*50)
|
||||
funcOutput = funcMap[callPk.procedure](**callPk.data)
|
||||
await w.send(base64.b64encode(pkl.dumps(funcOutput)))
|
||||
try:
|
||||
id = await w.recv()
|
||||
id = int(id)
|
||||
print(f"Starting Runner, ID: {id}")
|
||||
await w.send(base64.b64encode(pkl.dumps({"methods":list(funcMap.keys())})).decode("utf-8"))
|
||||
while True:
|
||||
counter+=1
|
||||
packetBytes=await w.recv()
|
||||
callPk:_CallPacket = pkl.loads(packetBytes)
|
||||
print("-"*50 + f"\nRunning: {callPk.procedure}\nArgs: {callPk.data}\nCounter: {counter}\n" + "-"*50)
|
||||
funcOutput = funcMap[callPk.procedure](**callPk.data)
|
||||
await w.send(base64.b64encode(pkl.dumps(funcOutput)))
|
||||
except WebSocketException as e:
|
||||
print(f"Closing Conncetion with Broker, total call count: {counter}")
|
||||
await w.close()
|
||||
|
||||
def startRunner(funcMapping, host, port):
|
||||
asyncio.run(_test(funcMapping, f"ws://{host}:{port}/reg"))
|
||||
asyncio.run(_send(funcMapping, f"ws://{host}:{port}/reg"))
|
||||
|
||||
Reference in New Issue
Block a user