271 lines
9.8 KiB
Python
271 lines
9.8 KiB
Python
"""HTTP 服务器生命周期管理。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import socket
|
||
import time
|
||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||
from pathlib import Path
|
||
|
||
from aide.core import output
|
||
from aide.core.config import ConfigManager
|
||
from aide.decide.errors import DecideError
|
||
from aide.decide.handlers import DecideHandlers, Response
|
||
from aide.decide.storage import DecideStorage
|
||
|
||
|
||
class DecideHTTPServer(HTTPServer):
    """HTTPServer that also carries the DecideHandlers instance it serves.

    The handlers object is attached before the base constructor runs, so the
    attribute already exists while the socket is being bound and activated.
    """

    handlers: DecideHandlers

    def __init__(self, server_address, RequestHandlerClass, handlers: DecideHandlers):
        self.handlers = handlers
        # bind_and_activate=True is the base-class default, spelled out here.
        super().__init__(server_address, RequestHandlerClass, bind_and_activate=True)
|
||
|
||
|
||
class DecideServer:
    """Start, run, and shut down the decision HTTP service.

    ``start()`` runs in the foreground and reports progress through
    ``aide.core.output``; ``start_daemon()`` runs without console output and
    records the service address via ``storage.save_server_info``.
    """

    # Number of consecutive ports tried, starting at the configured port.
    _PORT_ATTEMPTS = 10
    # Maximum accepted POST body size in bytes.
    _MAX_BODY_BYTES = 1024 * 1024
    # Per-connection socket timeout (seconds) applied to request handlers.
    # The server handles one connection at a time, so without a timeout an
    # idle HTTP/1.1 keep-alive connection could block the serve loop (and its
    # deadline / close checks) indefinitely.
    _REQUEST_TIMEOUT = 5.0

    def __init__(self, root, storage: DecideStorage):
        self.root = root
        self.storage = storage
        self.port = 3721
        self.timeout = 0
        self.bind = "127.0.0.1"
        self.url = ""
        # Web assets live inside the aide package directory, not the project root.
        self.web_dir = Path(__file__).parent / "web"
        self.should_close = False
        self.close_reason: str | None = None
        self.httpd: DecideHTTPServer | None = None

    def start(self) -> bool:
        """Run the service in the foreground until it is stopped.

        Returns:
            False if setup raises DecideError or no candidate port can be
            bound; True once the serve loop exits (completed, timed out,
            or interrupted).
        """
        try:
            start_port = self._load_settings()
            end_port = start_port + self._PORT_ATTEMPTS - 1
            if not self._create_httpd(start_port):
                output.err(f"无法启动服务: 端口 {start_port}-{end_port} 均被占用")
                output.info("建议: 关闭占用端口的程序,或在配置中指定其他端口")
                return False

            output.info("Web 服务已启动")
            output.info(f"请访问: {self._access_url()}")
            output.info("等待用户完成决策...")

            self._serve_forever()

            if self.close_reason == "completed":
                output.ok("决策已完成")
            elif self.close_reason == "timeout":
                output.warn("服务超时,已自动关闭")
            elif self.close_reason == "interrupted":
                output.warn("服务已中断")
            return True
        except DecideError as exc:
            self._close_httpd()
            output.err(str(exc))
            return False

    def stop(self, reason: str) -> None:
        """Request shutdown of the serve loop.

        Only the first call takes effect; its ``reason`` is kept in
        ``close_reason`` and later calls are ignored.
        """
        if self.should_close:
            return
        self.should_close = True
        self.close_reason = reason

    def start_daemon(self, pid: int) -> bool:
        """Run the service as a background process (called from daemon.py).

        Differences from start():
        - writes nothing to stdout (background process)
        - saves the service info (pid, port, access URL) via storage
        - clears the service info on exit

        Returns:
            False if no candidate port can be bound or any exception occurs;
            True after the serve loop exits normally.
        """
        try:
            start_port = self._load_settings()
            if not self._create_httpd(start_port):
                return False

            # Save service info so the CLI can find the running service.
            self.storage.save_server_info(pid, self.port, self._access_url())

            # Block until the user finishes, the timeout elapses, or interrupt.
            self._serve_forever()

            self.storage.clear_server_info()
            return True
        except Exception:
            # No console in daemon mode: signal failure via the return value,
            # release the listening socket if it was bound, and remove any
            # saved service info.
            self._close_httpd()
            self.storage.clear_server_info()
            return False

    def _load_settings(self) -> int:
        """Load the ``decide`` config section into attributes.

        Sets ``timeout``, ``bind`` and ``url``; returns the first port to try.
        """
        config = ConfigManager(self.root).load_config()
        self.timeout = _get_int(config, "decide", "timeout", default=0)
        self.bind = _get_str(config, "decide", "bind", default="127.0.0.1")
        self.url = _get_str(config, "decide", "url", default="")
        return _get_int(config, "decide", "port", default=3721)

    def _create_httpd(self, start_port: int) -> bool:
        """Bind the HTTP server on the first usable port in the candidate range.

        The real server is bound directly, instead of probing with a throwaway
        socket first, so another process cannot grab the port between the
        probe and the actual bind.

        Returns:
            True with ``self.httpd`` and ``self.port`` set, or False if every
            candidate port failed to bind.
        """
        handlers = DecideHandlers(
            storage=self.storage,
            web_dir=self.web_dir,
            stop_callback=self.stop,
        )
        request_handler = self._build_request_handler(handlers)
        for port in range(start_port, start_port + self._PORT_ATTEMPTS):
            try:
                httpd = DecideHTTPServer((self.bind, port), request_handler, handlers)
            except (OSError, OverflowError):
                # OSError: port in use or address not bindable;
                # OverflowError: port number outside 0-65535.
                continue
            # handle_request() returns after 1s idle so the loop can re-check flags.
            httpd.timeout = 1.0
            self.httpd = httpd
            # Read back the bound port (differs from `port` when it was 0).
            self.port = httpd.server_address[1]
            return True
        return False

    def _access_url(self) -> str:
        """Return the configured ``url`` if set, else ``http://localhost:<port>``."""
        return self.url or f"http://localhost:{self.port}"

    def _close_httpd(self) -> None:
        """Close the listening socket if one exists; safe to call repeatedly."""
        if self.httpd is None:
            return
        try:
            self.httpd.server_close()
        except OSError:
            pass

    def _find_available_port(self, start: int) -> int | None:
        """Return the first port from ``start`` that can be bound on ``self.bind``.

        Tries ``_PORT_ATTEMPTS`` consecutive ports and returns None if none
        can be bound. No longer used by start()/start_daemon(), which bind
        directly; retained for compatibility.
        """
        for port in range(start, start + self._PORT_ATTEMPTS):
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                try:
                    sock.bind((self.bind, port))
                except (OSError, OverflowError):
                    continue
                return port
        return None

    def _serve_forever(self) -> None:
        """Handle requests until stop() is called, the timeout elapses, or Ctrl+C.

        ``self.timeout`` <= 0 means no deadline. The listening socket is
        always closed on exit.
        """
        if self.httpd is None:
            return
        # Monotonic clock: the deadline is unaffected by wall-clock changes.
        deadline = None if self.timeout <= 0 else time.monotonic() + self.timeout

        try:
            while not self.should_close:
                if deadline is not None and time.monotonic() >= deadline:
                    self.stop("timeout")
                    break
                self.httpd.handle_request()
        except KeyboardInterrupt:
            self.stop("interrupted")
        finally:
            self._close_httpd()

    def _build_request_handler(self, handlers: DecideHandlers):
        """Create a BaseHTTPRequestHandler subclass bound to ``handlers``.

        The class forwards GET/POST/OPTIONS to ``handlers.handle`` and writes
        the returned (status, headers, body) tuple.
        """
        server = self

        class RequestHandler(BaseHTTPRequestHandler):
            protocol_version = "HTTP/1.1"
            # Socket timeout per connection (see DecideServer._REQUEST_TIMEOUT).
            timeout = server._REQUEST_TIMEOUT

            def do_GET(self):
                self._dispatch("GET")

            def do_POST(self):
                self._dispatch("POST")

            def do_OPTIONS(self):
                self._dispatch("OPTIONS")

            def _dispatch(self, method: str) -> None:
                """Read the POST body (if any), delegate to handlers, send the result."""
                body = b""
                if method == "POST":
                    length = self.headers.get("Content-Length")
                    try:
                        content_length = int(length) if length else 0
                    except ValueError:
                        content_length = -1
                    if content_length < 0:
                        # Unparseable or negative: rfile.read(<0) would block
                        # until EOF. Close afterwards since the body framing
                        # is unknown.
                        self._send_response(
                            self._json_response(400, '{"error":"决策数据无效","detail":"无效的 Content-Length"}'),
                            close=True,
                        )
                        return
                    if content_length > server._MAX_BODY_BYTES:
                        # The body is left unread, so the connection cannot be reused.
                        self._send_response(
                            self._json_response(413, '{"error":"请求体过大","detail":"单次提交限制 1MB"}'),
                            close=True,
                        )
                        return
                    body = self.rfile.read(content_length)

                try:
                    response = handlers.handle(method, self.path, body)
                except Exception as exc:  # pragma: no cover - last-resort guard
                    payload = (
                        '{"error":"服务器内部错误","detail":'
                        + json.dumps(str(exc), ensure_ascii=False)
                        + "}"
                    )
                    response = self._json_response(500, payload)

                # If handlers requested shutdown, end this keep-alive connection
                # after the response so the serve loop can observe the flag.
                self._send_response(response, close=server.should_close)

            def _json_response(self, status: int, payload: str) -> Response:
                """Build a JSON (status, headers, body) tuple with CORS headers."""
                headers = handlers._cors_headers({"Content-Type": "application/json; charset=utf-8"})
                return (status, headers, payload.encode("utf-8"))

            def log_message(self, format: str, *args) -> None:  # noqa: A003
                # Silence request logging so it does not clutter CLI output.
                return

            def _send_response(self, response: Response, close: bool = False) -> None:
                """Write status, headers, Content-Length and body.

                When ``close`` is true a ``Connection: close`` header is sent;
                BaseHTTPRequestHandler.send_header then also sets
                ``close_connection`` so the connection ends after this response.
                """
                status, headers, body = response
                self.send_response(status)
                for key, value in headers.items():
                    self.send_header(key, value)
                self.send_header("Content-Length", str(len(body)))
                if close:
                    self.send_header("Connection", "close")
                self.end_headers()
                if body:
                    self.wfile.write(body)

        return RequestHandler
|
||
|
||
|
||
def _get_int(config: dict, section: str, key: str, default: int) -> int:
|
||
try:
|
||
section_data = config.get(section, {}) if isinstance(config, dict) else {}
|
||
value = section_data.get(key, default)
|
||
if isinstance(value, bool):
|
||
return default
|
||
if isinstance(value, (int, float)):
|
||
as_int = int(value)
|
||
return as_int if as_int >= 0 else default
|
||
except Exception:
|
||
return default
|
||
return default
|
||
|
||
|
||
def _get_str(config: dict, section: str, key: str, default: str) -> str:
|
||
try:
|
||
section_data = config.get(section, {}) if isinstance(config, dict) else {}
|
||
value = section_data.get(key, default)
|
||
if isinstance(value, str):
|
||
return value
|
||
except Exception:
|
||
return default
|
||
return default
|