added safe exit, bulk call, polling delay for broker
This commit is contained in:
@@ -10,13 +10,14 @@ from ._callSpec import _CallPacket, _ClientPacket
|
|||||||
__all__ = ["runBroker"]
|
__all__ = ["runBroker"]
|
||||||
|
|
||||||
class _Broker:
|
class _Broker:
|
||||||
def __init__(self) -> None:
|
def __init__(self, pollingDelay=0.5) -> None:
|
||||||
self.router = fastapi.APIRouter()
|
self.router = fastapi.APIRouter()
|
||||||
self.router.add_api_websocket_route("/reg", self.registerRunner)
|
self.router.add_api_websocket_route("/reg", self.registerRunner)
|
||||||
self.router.add_api_route("/cliReq", self.clientRequest, methods=["POST"])
|
self.router.add_api_route("/cliReq", self.clientRequest, methods=["POST"])
|
||||||
self.taskQueue = asyncio.Queue()
|
self.taskQueue = asyncio.Queue()
|
||||||
self.runnerCount=0
|
self.runnerCount=0
|
||||||
self.returnDict = {}
|
self.returnDict = {}
|
||||||
|
self.pollingDelay = pollingDelay
|
||||||
|
|
||||||
|
|
||||||
async def registerRunner(self, wsConnection: fastapi.WebSocket):
|
async def registerRunner(self, wsConnection: fastapi.WebSocket):
|
||||||
@@ -28,28 +29,23 @@ class _Broker:
|
|||||||
self.runnerCount+=1
|
self.runnerCount+=1
|
||||||
while True:
|
while True:
|
||||||
reqID, data = await self.taskQueue.get()
|
reqID, data = await self.taskQueue.get()
|
||||||
# await asyncio.sleep(1)
|
|
||||||
await wsConnection.send_bytes(pkl.dumps(data))
|
await wsConnection.send_bytes(pkl.dumps(data))
|
||||||
retValue = await wsConnection.receive()
|
retValue = await wsConnection.receive()
|
||||||
# print(retValue)
|
|
||||||
self.returnDict[reqID] = retValue["bytes"]
|
self.returnDict[reqID] = retValue["bytes"]
|
||||||
# print(retValue["bytes"])
|
|
||||||
print(f"Tasks left: {self.taskQueue.qsize()}")
|
print(f"Tasks left: {self.taskQueue.qsize()}")
|
||||||
|
|
||||||
async def clientRequest(self, data:_ClientPacket):
|
async def clientRequest(self, data:_ClientPacket):
|
||||||
# print(data)
|
|
||||||
reqID = uuid.uuid4().hex
|
reqID = uuid.uuid4().hex
|
||||||
callPacket = pkl.loads(base64.b64decode(data.data))
|
callPacket = pkl.loads(base64.b64decode(data.data))
|
||||||
await self.taskQueue.put((reqID, callPacket))
|
await self.taskQueue.put((reqID, callPacket))
|
||||||
# print(self.taskQueue.qsize)
|
|
||||||
while reqID not in self.returnDict:
|
while reqID not in self.returnDict:
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(self.pollingDelay)
|
||||||
await asyncio.sleep(1)
|
# await asyncio.sleep(1)
|
||||||
returnValue = self.returnDict[reqID]
|
returnValue = self.returnDict[reqID]
|
||||||
return returnValue
|
return returnValue
|
||||||
|
|
||||||
def runBroker(host, port):
|
def runBroker(host, port, pollingDelay=0.1):
|
||||||
br = _Broker()
|
br = _Broker(pollingDelay=pollingDelay)
|
||||||
app = fastapi.FastAPI()
|
app = fastapi.FastAPI()
|
||||||
app.include_router(br.router)
|
app.include_router(br.router)
|
||||||
serverConf = Config(app = app, host=host, port=port, log_level=LOG_LEVELS["warning"], ws_ping_interval=10, ws_ping_timeout=None)
|
serverConf = Config(app = app, host=host, port=port, log_level=LOG_LEVELS["warning"], ws_ping_interval=10, ws_ping_timeout=None)
|
||||||
|
|||||||
@@ -1,18 +1,43 @@
|
|||||||
|
import time
|
||||||
import requests as req
|
import requests as req
|
||||||
from ._callSpec import _CallPacket
|
from ._callSpec import _CallPacket
|
||||||
import pickle as pkl
|
import pickle as pkl
|
||||||
import base64
|
import base64
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
__all__ = ["Client"]
|
__all__ = ["Client"]
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
def __init__(self, host, port) -> None:
|
def __init__(self, host, port) -> None:
|
||||||
self._url = f"http://{host}:{port}/cliReq"
|
self._url = f"http://{host}:{port}/cliReq"
|
||||||
|
self.tasks = []
|
||||||
|
|
||||||
def rCall(self, function, **kwargs):
|
def singleCall(self, function, **kwargs):
|
||||||
callPacket = _CallPacket(procedure=function, data=kwargs)
|
callPacket = _CallPacket(procedure=function, data=kwargs)
|
||||||
payload = {"data": base64.b64encode(pkl.dumps(callPacket)).decode("utf-8")}
|
payload = {"data": base64.b64encode(pkl.dumps(callPacket)).decode("utf-8")}
|
||||||
resp = req.post(self._url, json=payload)
|
resp = req.post(self._url, json=payload)
|
||||||
# print(resp.status_code)
|
|
||||||
# print(resp.text)
|
|
||||||
return pkl.loads(base64.b64decode(resp.text))
|
return pkl.loads(base64.b64decode(resp.text))
|
||||||
|
|
||||||
|
def addCall(self, function, **kwargs):
|
||||||
|
self.tasks.append((function, kwargs))
|
||||||
|
print(f"Total in Queue: {len(self.tasks)}")
|
||||||
|
|
||||||
|
def runAllCalls(self, callDelay=0.01):
|
||||||
|
if len(self.tasks) == 0:
|
||||||
|
return []
|
||||||
|
self.returnValues = [0]*len(self.tasks)
|
||||||
|
self.done = [0] * len(self.tasks)
|
||||||
|
for callIDX in range(len(self.tasks)):
|
||||||
|
t = Thread(target=self._threadWorker, args=[callIDX, self.tasks[callIDX]])
|
||||||
|
t.start()
|
||||||
|
time.sleep(callDelay)
|
||||||
|
while not all(self.done):
|
||||||
|
time.sleep(1)
|
||||||
|
self.tasks = []
|
||||||
|
return self.returnValues
|
||||||
|
|
||||||
|
def _threadWorker(self, callIDX, payload):
|
||||||
|
print(callIDX, payload)
|
||||||
|
ret = self.singleCall(function=payload[0], **payload[1])
|
||||||
|
self.returnValues[callIDX] = ret
|
||||||
|
self.done[callIDX] =1
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
import base64
|
import base64
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
# from fastapi import WebSocketException
|
||||||
from websockets.asyncio import client as WSC
|
from websockets.asyncio import client as WSC
|
||||||
|
from websockets.exceptions import WebSocketException
|
||||||
import asyncio
|
import asyncio
|
||||||
import pickle as pkl
|
import pickle as pkl
|
||||||
from ._callSpec import _CallPacket
|
from ._callSpec import _CallPacket
|
||||||
|
|
||||||
__all__ = ["startRunner"]
|
__all__ = ["startRunner"]
|
||||||
|
|
||||||
async def _test(funcMap: Dict[str, Any], url):
|
async def _send(funcMap: Dict[str, Any], url):
|
||||||
counter=0
|
counter=0
|
||||||
async with WSC.connect(url, open_timeout=None, ping_interval=10, ping_timeout=None ) as w:
|
async with WSC.connect(url, open_timeout=None, ping_interval=10, ping_timeout=None ) as w:
|
||||||
|
try:
|
||||||
id = await w.recv()
|
id = await w.recv()
|
||||||
id = int(id)
|
id = int(id)
|
||||||
print(f"Starting Runner, ID: {id}")
|
print(f"Starting Runner, ID: {id}")
|
||||||
@@ -18,9 +21,12 @@ async def _test(funcMap: Dict[str, Any], url):
|
|||||||
counter+=1
|
counter+=1
|
||||||
packetBytes=await w.recv()
|
packetBytes=await w.recv()
|
||||||
callPk:_CallPacket = pkl.loads(packetBytes)
|
callPk:_CallPacket = pkl.loads(packetBytes)
|
||||||
print("-"*50 + f"\nRunning: {callPk.procedure}\nArgs: {callPk.data}\n" + "-"*50)
|
print("-"*50 + f"\nRunning: {callPk.procedure}\nArgs: {callPk.data}\nCounter: {counter}\n" + "-"*50)
|
||||||
funcOutput = funcMap[callPk.procedure](**callPk.data)
|
funcOutput = funcMap[callPk.procedure](**callPk.data)
|
||||||
await w.send(base64.b64encode(pkl.dumps(funcOutput)))
|
await w.send(base64.b64encode(pkl.dumps(funcOutput)))
|
||||||
|
except WebSocketException as e:
|
||||||
|
print(f"Closing Conncetion with Broker, total call count: {counter}")
|
||||||
|
await w.close()
|
||||||
|
|
||||||
def startRunner(funcMapping, host, port):
|
def startRunner(funcMapping, host, port):
|
||||||
asyncio.run(_test(funcMapping, f"ws://{host}:{port}/reg"))
|
asyncio.run(_send(funcMapping, f"ws://{host}:{port}/reg"))
|
||||||
|
|||||||
Reference in New Issue
Block a user