144 lines
5.3 KiB
Python
144 lines
5.3 KiB
Python
import asyncio
|
|
import base64
|
|
import pickle as pkl
|
|
import uuid
|
|
from threading import Lock, Thread
|
|
|
|
import fastapi
|
|
from uvicorn import Config, Server
|
|
from uvicorn.config import LOG_LEVELS
|
|
|
|
from .._callSpec import _ClientPacket
|
|
|
|
# Public API of this module: `from ... import *` exports only the server entry point.
__all__ = ["startRouter"]
|
|
|
|
|
|
class _Router:
    """
    Broker that hands client task packets to connected runners over WebSockets.

    Routes registered on ``self.router``:
        ``/reg``            -- a runner registers, then serves tasks until it fails.
        ``/cliReq/{count}`` -- a client submits a batch of task packets.

    Tasks travel through ``taskQueue`` as ``(reqID, packet)`` tuples; a runner's
    result is published in ``returnDict[reqID]`` where ``clientRequest`` picks it up.
    """

    def __init__(
        self,
        pollingDelay: float | int = 0.5,
        recycleOnError: bool = False,
        softLimit: int = 50,
    ) -> None:
        """
        Args:
            pollingDelay: Seconds ``clientRequest`` sleeps between result checks.
            recycleOnError: When true, a runner's in-flight task is put back on
                the queue if that runner fails.
            softLimit: Maximum number of tasks from one batch request that are
                dispatched concurrently (must be >= 1).

        Raises:
            ValueError: If ``softLimit`` is less than 1 (a zero-sized limiter
                would never let a batch make progress).
        """
        if softLimit < 1:
            raise ValueError(f"softLimit must be >= 1, got {softLimit}")
        self.router = fastapi.APIRouter()
        self.router.add_api_websocket_route("/reg", self.registerRunner)
        self.router.add_api_websocket_route("/cliReq/{count}", self.multiClientRequest)
        self.taskQueue = asyncio.Queue()
        self.runnerCount = 0
        self.returnDict = {}
        self.doneDict = {}
        self.pollingDelay = pollingDelay
        self.recycleOnError = recycleOnError
        self.softLimit = softLimit

    async def registerRunner(self, wsConnection: fastapi.WebSocket):
        """
        Serve one runner connection: handshake, then feed it tasks until it fails.

        Handshake: accept, send the runner its numeric ID as text, then read one
        text frame containing a base64-encoded pickle whose ``"methods"`` entry
        is logged.

        Loop: take ``(reqID, packet)`` from ``taskQueue``, send the pickled
        packet as a binary frame, and store the runner's reply (base64-encoded
        pickle in a binary frame) in ``returnDict[reqID]``.

        A reply without a binary payload (for example a disconnect message) is
        treated as a runner crash. Any exception ends the loop. If
        ``recycleOnError`` is set, the in-flight task is re-queued. Otherwise, in
        the crash case only, the original packet is stored as the task's result
        so the waiting client is released.
        """
        await wsConnection.accept()
        # Claim the ID before the next await: the handshake awaits below would
        # otherwise let concurrent registrations see (and be told) the same number.
        runnerID = self.runnerCount
        self.runnerCount += 1
        await wsConnection.send_text(str(runnerID))
        hello = await wsConnection.receive()
        # NOTE(security): unpickling socket data can execute arbitrary code;
        # only trusted runners must be able to reach this endpoint.
        methods = pkl.loads(base64.b64decode(hello["text"]))
        print(f"Runner Connected with ID: {runnerID}, Methods: {methods['methods']}")

        handled = 0
        pending = None  # (reqID, packet) currently owned by this runner, if any
        try:
            while True:
                # TODO: recycle this runner's ID, and distinguish a failure before
                # the runner received the task (always safe to re-queue) from a
                # failure while running it (re-queue only if recycleOnError).
                # NOTE(review): a runner that disconnects while idle appears to go
                # unnoticed until the next task is sent to it.
                pending = await self.taskQueue.get()
                reqID, data = pending
                handled += 1
                print(f"Runr {runnerID} Counter: {handled}")
                await wsConnection.send_bytes(pkl.dumps(data))
                reply = await wsConnection.receive()
                # Per ASGI, "bytes" is optional and may be None for text frames.
                if reply.get("bytes") is None:
                    if not self.recycleOnError:
                        # Release the waiting client; it gets its own packet back.
                        # Skipped when recycling, so the re-run's real result is
                        # the only one ever published.
                        self.returnDict[reqID] = data
                    raise Exception(
                        f"Runner {runnerID} Crashed!! Check for function error logs on the Runner."
                    )
                self.returnDict[reqID] = pkl.loads(base64.b64decode(reply["bytes"]))
                pending = None
        except Exception as e:
            print(e)
            if self.recycleOnError and pending is not None:
                await self.taskQueue.put(pending)
                print("Recycled this task")
        print(f"Runner {runnerID} Closed")

    async def clientRequest(self, data: _ClientPacket):
        """
        Queue a single task packet and wait for its result.

        Enqueues ``(reqID, data)`` under a fresh UUID. It then polls
        ``returnDict`` every ``pollingDelay`` seconds until a runner publishes
        the result, and pops and returns that result.

        To be deprecated for better task handling.
        """
        reqID = uuid.uuid4().hex
        await self.taskQueue.put((reqID, data))
        while reqID not in self.returnDict:
            await asyncio.sleep(self.pollingDelay)
        return self.returnDict.pop(reqID)

    async def multiClientRequest(self, wsConn: fastapi.WebSocket, count: int):
        """
        Run a batch of task packets for one client and reply with all results.

        Protocol:
            1. Accept, then echo ``count`` back as text.
            2. Receive one binary frame holding a pickled sequence of packets.
            3. Run every packet through ``clientRequest``, with at most
               ``self.softLimit`` in flight.
            4. Send the pickled list of results, in input order.

        ``count`` is only echoed in the handshake. The batch size is taken from
        the payload, so a mismatch can neither hang the request nor overflow the
        result list.
        """
        await wsConn.accept()
        await wsConn.send_text(str(count))
        # NOTE(security): pickle.loads on client-supplied bytes executes
        # arbitrary code; this endpoint must only be reachable by trusted clients.
        taskPackets = pkl.loads(await wsConn.receive_bytes())

        batchID = uuid.uuid4().hex
        results = [0] * len(taskPackets)
        done = [0] * len(taskPackets)
        # Progress stays observable under the batch ID while the batch runs;
        # both entries are removed in `finally` so they cannot leak.
        self.returnDict[batchID] = results
        self.doneDict[batchID] = done
        limiter = asyncio.Semaphore(self.softLimit)

        async def runOne(idx, packet):
            async with limiter:
                results[idx] = await self.clientRequest(packet)
                done[idx] = 1

        try:
            # Stay on this event loop (no per-task threads/loops), so the shared
            # asyncio.Queue is only ever used from the loop that owns it.
            await asyncio.gather(*(runOne(i, p) for i, p in enumerate(taskPackets)))
            await wsConn.send_bytes(pkl.dumps(results))
        finally:
            self.returnDict.pop(batchID, None)
            self.doneDict.pop(batchID, None)

    def _worker(self, id, idx, data: _ClientPacket):
        """
        Thread worker to handle one task (legacy; no longer used by this class).

        Runs ``clientRequest`` on a fresh event loop and stores the result at
        ``returnDict[id][idx]``, then marks ``doneDict[id][idx]``. Calling it
        from another thread touches ``taskQueue`` from outside its event loop,
        which asyncio does not support; prefer ``multiClientRequest``.

        To be deprecated for better task handling.
        """
        retVal = asyncio.run(self.clientRequest(data))
        self.returnDict[id][idx] = retVal
        self.doneDict[id][idx] = 1
        return
|
|
|
|
|
|
def startRouter(
    host: str,
    port: str | int,
    pollingDelay: float | int = 0.1,
    logLevel: int = 3,
    recycleOnError: bool = False,
):
    """
    Main function to start the router system: build the app and serve it.

    Blocks in ``uvicorn.Server.run()`` until the server shuts down.

    Args:
        host: Interface to bind.
        port: Port to bind; converted with ``int()``.
        pollingDelay: Seconds between result polls in ``_Router.clientRequest``.
        logLevel: Index into ``uvicorn.config.LOG_LEVELS`` (insertion order)
            selecting the server log level. NOTE(review): the default 3 relies
            on uvicorn's ordering (critical, error, warning, info, ...) to mean
            "info".
        recycleOnError: Forwarded to ``_Router``; when true, a task whose runner
            fails is put back on the task queue.

    Raises:
        IndexError: If ``logLevel`` is not a valid index into ``LOG_LEVELS``.
    """
    br = _Router(pollingDelay=pollingDelay, recycleOnError=recycleOnError)
    app = fastapi.FastAPI()
    app.include_router(br.router)

    levelNames = list(LOG_LEVELS.keys())
    if not -len(levelNames) <= logLevel < len(levelNames):
        raise IndexError(
            f"logLevel {logLevel} out of range; valid indices 0..{len(levelNames) - 1} "
            f"map to {levelNames}"
        )
    level = levelNames[logLevel]

    serverConf = Config(
        app=app,
        host=host,
        port=int(port),
        log_level=LOG_LEVELS[level],
        ws_ping_interval=10,
        ws_ping_timeout=None,
        ws_max_size=1024 * 1024 * 1024,  # 1 GiB per incoming WebSocket message
    )
    server = Server(config=serverConf)
    server.run()
|