Files
harmoney/harmoney/router/router.py
2026-03-31 10:09:00 +05:30

144 lines
5.3 KiB
Python

import asyncio
import base64
import pickle as pkl
import uuid
from threading import Lock, Thread
import fastapi
from uvicorn import Config, Server
from uvicorn.config import LOG_LEVELS
from .._callSpec import _ClientPacket
# Only the server entry point is exported; _Router is an internal helper class.
__all__ = ["startRouter"]
class _Router:
    """
    Websocket task broker between clients and runners.

    Clients submit call packets on ``/cliReq/{count}``. Each packet is put on
    an in-process ``asyncio.Queue`` and handed to whichever runner connected
    on ``/reg`` takes it next. Results flow back through ``returnDict``, keyed
    by a per-packet hex UUID.

    All state is meant to be touched only from the server's event loop.
    """

    def __init__(
        self, pollingDelay: float | int = 0.5, recycleOnError: bool = False
    ) -> None:
        """
        Args:
            pollingDelay: Seconds ``clientRequest`` sleeps between checks for
                its result in ``returnDict``.
            recycleOnError: When True, a task whose runner fails after the task
                was sent is put back on the queue. When False, the waiting
                client is answered with its own call packet.
        """
        self.router = fastapi.APIRouter()
        self.router.add_api_websocket_route("/reg", self.registerRunner)
        self.router.add_api_websocket_route("/cliReq/{count}", self.multiClientRequest)
        # Pending (reqID, callPacket) pairs. asyncio.Queue is not thread-safe,
        # so only use it from the server's event loop.
        self.taskQueue = asyncio.Queue()
        # Next ID handed to a registering runner.
        self.runnerCount = 0
        # reqID -> result; written by registerRunner, popped by clientRequest.
        self.returnDict = {}
        # Per-batch completion flags; only used by the legacy _worker.
        self.doneDict = {}
        self.pollingDelay = pollingDelay
        self.recycleOnError = recycleOnError

    async def registerRunner(self, wsConnection: fastapi.WebSocket):
        """
        Serve one runner connection for as long as it lives.

        Handshake: accept the socket, send the runner its ID as text, then read
        one text message containing a base64-encoded pickle whose
        ``"methods"`` entry is logged.

        Loop: take ``(reqID, data)`` from ``taskQueue``, send ``pkl.dumps(data)``,
        wait for one reply, and store the base64/pickle-decoded result in
        ``returnDict[reqID]``.

        Failure handling:
        * If sending the task fails, the runner never received it, so the task
          is always re-queued.
        * If anything fails after the task was handed out (a non-bytes reply
          such as a disconnect, or an undecodable reply), the task is
          re-queued when ``recycleOnError`` is set. Otherwise its own call
          packet is stored as the result, so the waiting client is released.
        """
        # Claim the ID before the first await so that runners connecting
        # concurrently can never be given the same ID.
        runnerID = self.runnerCount
        self.runnerCount += 1
        await wsConnection.accept()
        await wsConnection.send_text(str(runnerID))
        hello = await wsConnection.receive()
        # SECURITY NOTE(review): pickle.loads on data from the network can run
        # arbitrary code; this route must only be reachable by trusted runners.
        methods = pkl.loads(base64.b64decode(hello["text"]))
        print(f"Runner Connected with ID: {runnerID}, Methods: {methods['methods']}")
        runnerCounter = 0
        # Task handed to this runner that has not been answered yet.
        inFlight = None
        try:
            while True:
                # TODO: return runnerID to a pool of reusable IDs on disconnect.
                reqID, data = await self.taskQueue.get()
                runnerCounter += 1
                print(f"Runr {runnerID} Counter: {runnerCounter}")
                inFlight = (reqID, data)
                payload = pkl.dumps(data)
                try:
                    await wsConnection.send_bytes(payload)
                except Exception:
                    # Pre-call failure: the runner never saw this task, so it
                    # is always safe to hand it to another runner.
                    inFlight = None
                    await self.taskQueue.put((reqID, data))
                    print("Recycled this task")
                    raise
                retValue = await wsConnection.receive()
                if "bytes" not in retValue:
                    raise Exception(
                        f"Runner {runnerID} Crashed!! Check for function error logs on the Runner."
                    )
                self.returnDict[reqID] = pkl.loads(base64.b64decode(retValue["bytes"]))
                inFlight = None
        except Exception as e:
            print(e)
            if inFlight is not None:
                if self.recycleOnError:
                    await self.taskQueue.put(inFlight)
                    print("Recycled this task")
                else:
                    # Release the waiting client; it gets its call packet back.
                    failedID, failedData = inFlight
                    self.returnDict[failedID] = failedData
        print(f"Runner {runnerID} Closed")

    async def clientRequest(self, data: _ClientPacket):
        """
        Enqueue one call packet and wait for its result.

        Polls ``returnDict`` every ``pollingDelay`` seconds until a runner has
        stored the result, then removes the entry and returns it. The result is
        the call packet itself if the runner failed and recycling is off.
        Must run on the server's event loop, because it uses ``taskQueue``.
        To be deprecated for better task handling.
        """
        reqID = uuid.uuid4().hex
        await self.taskQueue.put((reqID, data))
        while reqID not in self.returnDict:
            await asyncio.sleep(self.pollingDelay)
        return self.returnDict.pop(reqID)

    async def multiClientRequest(self, wsConn: fastapi.WebSocket, count: int):
        """
        Serve one batch client.

        Protocol: accept, echo ``count`` as text, receive one binary message
        containing ``pkl.dumps(list_of_packets)``, then reply with
        ``pkl.dumps(results)``, where ``results[i]`` answers packet ``i``.

        Packets run as coroutines on this event loop via ``clientRequest``.
        At most ``softLimit`` packets of the batch are queued or running at
        once. The reply holds one entry per packet actually received, so a
        ``count`` that disagrees with the payload no longer hangs the request.
        """
        await wsConn.accept()
        softLimit = 50
        await wsConn.send_text(str(count))
        taskBytes = await wsConn.receive_bytes()
        # SECURITY NOTE(review): unpickling client-supplied bytes can execute
        # arbitrary code; only expose this route to trusted clients.
        taskPackets = pkl.loads(taskBytes)
        limiter = asyncio.Semaphore(softLimit)

        async def runOne(packet):
            # Throttle how many of this batch's packets sit in the shared queue.
            async with limiter:
                return await self.clientRequest(packet)

        # gather preserves input order, so results line up with taskPackets.
        results = await asyncio.gather(*(runOne(p) for p in taskPackets))
        await wsConn.send_bytes(pkl.dumps(list(results)))

    def _worker(self, id, idx, data: _ClientPacket):
        """
        Legacy thread entry point: run one ``clientRequest`` to completion and
        record the result in ``returnDict[id][idx]`` and ``doneDict[id][idx]``.

        No longer used by ``multiClientRequest``. NOTE(review): it drives
        ``clientRequest`` on a fresh loop via ``asyncio.run``, which touches
        ``taskQueue`` (a non-thread-safe ``asyncio.Queue``) from outside the
        server's loop. Do not call it from worker threads.
        To be deprecated for better task handling.
        """
        retVal = asyncio.run(self.clientRequest(data))
        self.returnDict[id][idx] = retVal
        self.doneDict[id][idx] = 1
        return
def startRouter(
    host: str,
    port: str | int,
    pollingDelay: float | int = 0.1,
    logLevel: int = 3,
    recycleOnError: bool = False,
):
    """
    Main function to start the router system.

    Builds a ``_Router``, mounts its websocket routes on a FastAPI app and
    serves it with uvicorn. Blocks until the server stops.

    Args:
        host: Interface to bind.
        port: TCP port; anything ``int()`` accepts.
        pollingDelay: Seconds between result checks for each pending request.
        logLevel: Positional index into uvicorn's ``LOG_LEVELS`` mapping (key
            order). Presumably 3 is "info" with current uvicorn; confirm
            against the installed version.
        recycleOnError: Forwarded to ``_Router``. When True, a task whose
            runner fails is put back on the queue.
    """
    br = _Router(pollingDelay=pollingDelay, recycleOnError=recycleOnError)
    app = fastapi.FastAPI()
    app.include_router(br.router)
    level = list(LOG_LEVELS.keys())[logLevel]
    serverConf = Config(
        app=app,
        host=host,
        port=int(port),
        log_level=LOG_LEVELS[level],
        # Ping every 10 s but never time out on a missing pong, so runners
        # busy in long calls are not dropped.
        ws_ping_interval=10,
        ws_ping_timeout=None,
        # Allow messages up to 1 GiB (large pickled payloads).
        ws_max_size=1024 * 1024 * 1024,
    )
    server = Server(config=serverConf)
    server.run()