66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
class ElizabethService(Service):
    """ElizabethService, the direct backend of Ariadne.

    Holds the per-account connections and the ``Broadcast`` instance, exposes
    the aiohttp client session, and posts the application/account lifecycle
    events during Launart's ``preparing`` and ``cleanup`` stages.
    """

    id = "elizabeth.service"
    supported_interface_types = {ConnectionInterface}
    http_interface: AiohttpClientInterface
    # account number -> primary connection registered for that account
    connections: Dict[int, ConnectionMixin[U_Info]]
    broadcast: Broadcast

    def __init__(self) -> None:
        """Initialize ElizabethService.

        Obtains the ``Broadcast`` instance through ``creart`` and registers the
        ``ContextDispatcher`` / ``LaunartInterfaceDispatcher`` prelude
        dispatchers and the ``NoneDispatcher`` finale dispatcher on it.
        """
        import creart

        self.connections = {}
        self.broadcast = creart.it(Broadcast)
        # Membership checks keep each dispatcher from being registered twice
        # if the Broadcast instance already carries it.
        if ContextDispatcher not in self.broadcast.prelude_dispatchers:
            self.broadcast.prelude_dispatchers.append(ContextDispatcher)
        if LaunartInterfaceDispatcher not in self.broadcast.prelude_dispatchers:
            self.broadcast.prelude_dispatchers.append(LaunartInterfaceDispatcher)
        if NoneDispatcher not in self.broadcast.finale_dispatchers:
            self.broadcast.finale_dispatchers.append(NoneDispatcher)
        super().__init__()

    @staticmethod
    def base_telemetry() -> None:
        """Log the Ariadne logo followed by the installed distributions and their versions."""
        output: List[str] = [""]
        dist_map: Dict[str, str] = get_dist_map()
        output.extend(
            " ".join(
                [
                    f"[blue]{name}[/]:" if name.startswith("graiax-") else f"[magenta]{name}[/]:",
                    f"[green]{version}[/]",
                ]
            )
            for name, version in dist_map.items()
        )
        output.sort()
        output.insert(0, f"[cyan]{ARIADNE_ASCII_LOGO}[/]")
        rich_output = "\n".join(output)
        # The message uses rich-style "[tag]" markup; loguru's colors=True
        # expects "<tag>", so brackets are swapped. The original is passed as `alt`.
        logger.opt(colors=True).info(
            rich_output.replace("[", "<").replace("]", ">"), alt=rich_output, highlighter=None
        )

    @staticmethod
    async def check_update() -> None:
        """Check all installed distributions for updates and log the result."""
        output: List[str] = []
        dist_map: Dict[str, str] = get_dist_map()
        async with ClientSession() as session:
            # `check_update` here resolves to the module-level helper of the same
            # name (class attributes are not in scope inside a method body).
            # Presumably it appends one line to `output` per outdated
            # distribution -- confirm against the helper.
            await asyncio.gather(
                *(check_update(session, name, version, output) for name, version in dist_map.items())
            )
        output.sort()
        if output:
            output = (
                ["", "[bold]", f"[red]{len(output)}[/] [yellow]update(s) available:[/]"] + output + ["[/]"]
            )
            rich_output = "\n".join(output)
            logger.opt(colors=True).warning(
                rich_output.replace("[", "<").replace("]", ">"), alt=rich_output, highlighter=None
            )
        else:
            logger.opt(colors=True).success("All dependencies up to date!", style="green")

    def add_infos(self, infos: Iterable[U_Info]) -> Tuple[List[ConnectionMixin], int]:
        """Create connections from the given ``Info`` objects for a single account.

        Args:
            infos (Iterable[U_Info]): Connection configs. All of them must
                belong to the same, not yet registered, account.

        Raises:
            AriadneConfigurationError: If ``infos`` is empty, the account is
                already registered, or the configs span several accounts.
            Exception: Anything raised by ``add_info`` (e.g. ``ValueError`` on
                conflicting connections) propagates. The account entry is
                removed first, so a failed call leaves no partial registration.

        Returns:
            Tuple[List[ConnectionMixin], int]: The created connections and the account number.
        """
        infos = list(infos)
        if not infos:
            raise AriadneConfigurationError("No configs provided")
        account: int = infos[0].account
        if account in self.connections:
            raise AriadneConfigurationError(f"Account {account} already exists")
        if len({i.account for i in infos}) != 1:
            raise AriadneConfigurationError("All configs must be for the same account")
        # Stable sort with False < True puts HttpClientInfo configs last, so they
        # attach as fallbacks to an already-registered upstream connection.
        infos.sort(key=lambda x: isinstance(x, HttpClientInfo))
        try:
            conns: List[ConnectionMixin] = [self.add_info(conf) for conf in infos]
        except Exception:
            # The account was verified absent above, so anything stored under it
            # came from this call; drop it so the account can be re-added later.
            self.connections.pop(account, None)
            raise
        return conns, account

    def add_info(self, config: U_Info) -> ConnectionMixin:
        """Create and register the connection for a single ``Info``.

        The first connection created for an account becomes its primary
        connection. A later ``HttpClientConnection`` for the same account is
        hooked in as that connection's fallback and shares its status object.

        Args:
            config (U_Info): The connection config.

        Raises:
            ValueError: If the upstream connection already has a fallback, or
                the account already has a connection and the new one is not an
                ``HttpClientConnection``.

        Returns:
            ConnectionMixin: The newly created connection.
        """
        account: int = config.account
        connection = CONFIG_MAP[config.__class__](config)
        if account not in self.connections:
            self.connections[account] = connection
        elif isinstance(connection, HttpClientConnection):
            upstream_conn = self.connections[account]
            if upstream_conn.fallback:
                raise ValueError(f"{upstream_conn} already has fallback connection")
            connection.status = upstream_conn.status
            connection.is_hook = True
            upstream_conn.fallback = connection
        else:
            raise ValueError(f"Connection {self.connections[account]} conflicts with {connection}")
        return connection

    async def launch(self, mgr: Launart):
        """Launart entry point.

        ``preparing`` stage: binds the aiohttp client interface from ``mgr``,
        posts ``ApplicationLaunch`` for the default account (if configured),
        then, for every connection, installs a connection-failure callback and
        posts ``AccountLaunch``.

        ``cleanup`` stage: posts ``ApplicationShutdown`` / ``AccountShutdown``
        for apps whose connection is still available, cancels pending
        ``Broadcast.Executor`` tasks, and finally runs the dependency update check.

        Args:
            mgr (Launart): The Launart manager running this service.
        """
        from .app import Ariadne
        from .context import enter_context
        from .event.lifecycle import AccountLaunch, AccountShutdown, ApplicationLaunch, ApplicationShutdown

        def _make_connection_fail_cb(app: "Ariadne"):
            """Build the zero-argument connection-failure callback bound to ``app``.

            A factory is required: a closure defined directly in the loop below
            would late-bind the loop-rebound local ``app``. Every connection
            would then report failures against the last app assigned to it.
            """

            def _disconnect_cb():
                from graia.ariadne.event.lifecycle import AccountConnectionFail

                self.broadcast.postEvent(AccountConnectionFail(app))

            return _disconnect_cb

        self.base_telemetry()
        async with self.stage("preparing"):
            self.http_interface = mgr.get_interface(AiohttpClientInterface)
            if "default_account" in Ariadne.options:
                app = Ariadne.current()
                with enter_context(app=app):
                    self.broadcast.postEvent(ApplicationLaunch(app))
            for conn in self.connections.values():
                app = Ariadne.current(conn.info.account)
                conn._connection_fail = _make_connection_fail_cb(app)
                with enter_context(app=app):
                    self.broadcast.postEvent(AccountLaunch(app))
        async with self.stage("cleanup"):
            logger.info("Elizabeth Service cleaning up...", style="dark_orange")
            # Shutdown events are awaited (unlike the launch events above) so
            # their handlers run before pending executors are cancelled below.
            if "default_account" in Ariadne.options:
                app = Ariadne.current()
                if app.connection.status.available:
                    with enter_context(app=app):
                        await self.broadcast.postEvent(ApplicationShutdown(app))
            for conn in self.connections.values():
                if conn.status.available:
                    app = Ariadne.current(conn.info.account)
                    with enter_context(app=app):
                        await self.broadcast.postEvent(AccountShutdown(app))
            for task in asyncio.all_tasks():
                if task.done():
                    continue
                coro: Coroutine = task.get_coro()  # type: ignore
                # Defensive: the wrapped awaitable is not guaranteed to expose
                # __qualname__, and shutdown must not crash on such a task.
                if getattr(coro, "__qualname__", None) == "Broadcast.Executor":
                    task.cancel()
                    logger.debug(f"Cancelled {task.get_name()} (Broadcast.Executor)")
            logger.info("Checking for updates...", alt="[cyan]Checking for updates...[/]")
            await self.check_update()

    @property
    def client_session(self) -> ClientSession:
        """Get the aiohttp ``ClientSession``.

        Returns:
            ClientSession: The session owned by the bound aiohttp client interface's service.
        """
        return self.http_interface.service.session

    @property
    def required(self):
        """Launart dependencies.

        Includes the aiohttp client interface and, for every connection, its
        declared dependencies and its own id.
        """
        dependencies = {AiohttpClientInterface}
        for conn in self.connections.values():
            dependencies |= conn.dependencies
            dependencies.add(conn.id)
        return dependencies

    @property
    def stages(self):
        """The Launart stages this service goes through."""
        return {"preparing", "cleanup"}

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """Get the bound event loop.

        Returns:
            asyncio.AbstractEventLoop: The event loop of the ``Broadcast`` instance.
        """
        return self.broadcast.loop

    @overload
    def get_interface(self, interface_type: Type[ConnectionInterface]) -> ConnectionInterface:
        ...

    @overload
    def get_interface(self, interface_type: type) -> None:
        ...

    def get_interface(self, interface_type: type):
        """Return a ``ConnectionInterface`` bound to this service if one is requested.

        Any other ``interface_type`` returns ``None``.
        """
        if interface_type is ConnectionInterface:
            return ConnectionInterface(self)
|