hat.mariner.client

  1from collections.abc import Collection
  2import asyncio
  3import contextlib
  4import itertools
  5import logging
  6import typing
  7
  8from hat import aio
  9from hat.drivers import tcp
 10import hat.event.common
 11
 12from hat.mariner import transport
 13
 14
 15mlog: logging.Logger = logging.getLogger(__name__)
 16
 17StatusCb: typing.TypeAlias = aio.AsyncCallable[
 18    ['Connection', hat.event.common.Status],
 19    None]
 20
 21EventsCb: typing.TypeAlias = aio.AsyncCallable[
 22    ['Connection', Collection[hat.event.common.Event]],
 23    None]
 24
 25
 26async def connect(addr: tcp.Address,
 27                  client_name: str,
 28                  *,
 29                  client_token: str | None = None,
 30                  subscriptions: Collection[hat.event.common.EventType] = [],
 31                  server_id: hat.event.common.ServerId | None = None,
 32                  persisted: bool = False,
 33                  ping_delay: float | None = 30,
 34                  ping_timeout: float = 30,
 35                  status_cb: StatusCb | None = None,
 36                  events_cb: EventsCb | None = None,
 37                  **kwargs
 38                  ) -> 'Connection':
 39    conn = Connection()
 40    conn._status_cb = status_cb
 41    conn._events_cb = events_cb
 42    conn._loop = asyncio.get_running_loop()
 43    conn._status = hat.event.common.Status.STANDBY
 44    conn._next_req_ids = itertools.count(1)
 45    conn._futures = {}
 46    conn._ping_event = asyncio.Event()
 47
 48    conn._conn = await transport.connect(addr, **kwargs)
 49
 50    try:
 51        init_req = transport.InitReqMsg(client_name=client_name,
 52                                        client_token=client_token,
 53                                        subscriptions=subscriptions,
 54                                        server_id=server_id,
 55                                        persisted=persisted)
 56        await conn._conn.send(init_req)
 57
 58        init_res = await conn._conn.receive()
 59        if not isinstance(init_res, transport.InitResMsg):
 60            raise Exception('invalid initiate response')
 61
 62        if not init_res.success:
 63            raise Exception('initiate error' if init_res.error is None
 64                            else f'initiate error: {init_res.error}')
 65
 66        conn._status = init_res.status
 67
 68        conn.async_group.spawn(conn._receive_loop)
 69
 70        if ping_delay:
 71            conn.async_group.spawn(conn._ping_loop, ping_delay, ping_timeout)
 72
 73    except BaseException:
 74        await aio.uncancellable(conn.async_close())
 75        raise
 76
 77    return conn
 78
 79
 80class Connection(aio.Resource):
 81
 82    @property
 83    def async_group(self) -> aio.Group:
 84        return self._conn.async_group
 85
 86    @property
 87    def status(self) -> hat.event.common.Status:
 88        return self._status
 89
 90    async def register(self,
 91                       register_events: Collection[hat.event.common.RegisterEvent] # NOQA
 92                       ) -> Collection[hat.event.common.Event] | None:
 93        register_id = next(self._next_req_ids)
 94        req = transport.RegisterReqMsg(register_id=register_id,
 95                                       register_events=register_events)
 96
 97        res = await self._send_req(req, register_id)
 98        if not isinstance(res, transport.RegisterResMsg):
 99            raise Exception('invalid register response')
100
101        return res.events
102
103    async def query(self,
104                    params: hat.event.common.QueryParams
105                    ) -> hat.event.common.QueryResult:
106        query_id = next(self._next_req_ids)
107        req = transport.QueryReqMsg(query_id=query_id,
108                                    params=params)
109
110        res = await self._send_req(req, query_id)
111        if not isinstance(res, transport.QueryResMsg):
112            raise Exception('invalid query response')
113
114        return res.result
115
116    async def _send_req(self, req, req_id):
117        if not self.is_open:
118            raise ConnectionError()
119
120        future = self._loop.create_future()
121        self._futures[req_id] = future
122
123        try:
124            await self._conn.send(req)
125
126            if not self.is_open:
127                raise ConnectionError()
128
129            return await future
130
131        finally:
132            self._futures.pop(req_id)
133
134    async def _receive_loop(self):
135        try:
136            while True:
137                msg = await self._conn.receive()
138
139                self._ping_event.set()
140
141                if isinstance(msg, transport.StatusMsg):
142                    self._status = msg.status
143
144                    if self._status_cb:
145                        await aio.call(self._status_cb, self, msg.status)
146
147                elif isinstance(msg, transport.EventsMsg):
148                    if self._events_cb:
149                        await aio.call(self._events_cb, self, msg.events)
150
151                elif isinstance(msg, transport.RegisterResMsg):
152                    future = self._futures.get(msg.register_id)
153                    if future and not future.done():
154                        future.set_result(msg)
155
156                elif isinstance(msg, transport.QueryResMsg):
157                    future = self._futures.get(msg.query_id)
158                    if future and not future.done():
159                        future.set_result(msg)
160
161                elif isinstance(msg, transport.PingResMsg):
162                    pass
163
164                else:
165                    raise Exception('invalid message type')
166
167        except ConnectionError:
168            pass
169
170        except Exception as e:
171            mlog.error("receive loop error: %s", e, exc_info=e)
172
173        finally:
174            self.close()
175
176            for future in self._futures.values():
177                if not future.done():
178                    future.set_exception(ConnectionError())
179
180    async def _ping_loop(self, delay, timeout):
181        try:
182            while True:
183                self._ping_event.clear()
184
185                with contextlib.suppress(asyncio.TimeoutError):
186                    await aio.wait_for(self._ping_event.wait(), delay)
187                    continue
188
189                req_id = next(self._next_req_ids)
190                req = transport.PingReqMsg(req_id)
191                await self._conn.send(req)
192
193                with contextlib.suppress(asyncio.TimeoutError):
194                    await aio.wait_for(self._ping_event.wait(), timeout)
195                    continue
196
197                mlog.debug("ping timeout")
198                break
199
200        except ConnectionError:
201            pass
202
203        finally:
204            self.close()
mlog: logging.Logger = <Logger hat.mariner.client (WARNING)>
StatusCb: TypeAlias = Callable[[ForwardRef('Connection'), hat.event.common.common.Status], None | Awaitable[None]]
EventsCb: TypeAlias = Callable[[ForwardRef('Connection'), Collection[hat.event.common.common.Event]], None | Awaitable[None]]
async def connect( addr: hat.drivers.tcp.Address, client_name: str, *, client_token: str | None = None, subscriptions: Collection[tuple[str, ...]] = [], server_id: int | None = None, persisted: bool = False, ping_delay: float | None = 30, ping_timeout: float = 30, status_cb: Optional[Callable[[Connection, hat.event.common.common.Status], None | Awaitable[None]]] = None, events_cb: Optional[Callable[[Connection, Collection[hat.event.common.common.Event]], None | Awaitable[None]]] = None, **kwargs) -> Connection:
async def connect(addr: tcp.Address,
                  client_name: str,
                  *,
                  client_token: str | None = None,
                  subscriptions: Collection[hat.event.common.EventType] = (),
                  server_id: hat.event.common.ServerId | None = None,
                  persisted: bool = False,
                  ping_delay: float | None = 30,
                  ping_timeout: float = 30,
                  status_cb: StatusCb | None = None,
                  events_cb: EventsCb | None = None,
                  **kwargs
                  ) -> 'Connection':
    """Connect to server and perform initiate handshake.

    Opens transport connection to `addr`, sends initiate request and waits
    for initiate response. On success, connection status is set from the
    initiate response and the background receive loop (plus the ping loop,
    if `ping_delay` is set) is started. On any failure, including
    cancellation, the transport connection is closed and the exception is
    re-raised.

    Args:
        addr: server address
        client_name: client name sent in initiate request
        client_token: optional client token sent in initiate request
        subscriptions: event types sent in initiate request
        server_id: optional server id sent in initiate request
        persisted: persisted flag sent in initiate request
        ping_delay: inactivity period after which a ping request is sent
            (ping loop is not started if ``None`` or ``0``)
        ping_timeout: time to wait for any received message after a ping
            request before the connection is closed
        status_cb: called with connection and new status on each received
            status message
        events_cb: called with connection and events on each received
            events message
        kwargs: additional arguments passed to `transport.connect`

    Raises:
        Exception: if initiate response is invalid or not successful

    """
    # NOTE: `subscriptions` defaults to an immutable empty tuple instead of a
    # shared mutable list default.
    conn = Connection()
    conn._status_cb = status_cb
    conn._events_cb = events_cb
    conn._loop = asyncio.get_running_loop()
    conn._status = hat.event.common.Status.STANDBY
    conn._next_req_ids = itertools.count(1)
    conn._futures = {}
    conn._ping_event = asyncio.Event()

    conn._conn = await transport.connect(addr, **kwargs)

    try:
        init_req = transport.InitReqMsg(client_name=client_name,
                                        client_token=client_token,
                                        subscriptions=subscriptions,
                                        server_id=server_id,
                                        persisted=persisted)
        await conn._conn.send(init_req)

        init_res = await conn._conn.receive()
        if not isinstance(init_res, transport.InitResMsg):
            raise Exception('invalid initiate response')

        if not init_res.success:
            raise Exception('initiate error' if init_res.error is None
                            else f'initiate error: {init_res.error}')

        conn._status = init_res.status

        conn.async_group.spawn(conn._receive_loop)

        if ping_delay:
            conn.async_group.spawn(conn._ping_loop, ping_delay, ping_timeout)

    except BaseException:
        # close transport on any failure (cancellation included) so the
        # connection is not leaked; closing itself must not be cancelled
        await aio.uncancellable(conn.async_close())
        raise

    return conn
class Connection(hat.aio.group.Resource):
 81class Connection(aio.Resource):
 82
 83    @property
 84    def async_group(self) -> aio.Group:
 85        return self._conn.async_group
 86
 87    @property
 88    def status(self) -> hat.event.common.Status:
 89        return self._status
 90
 91    async def register(self,
 92                       register_events: Collection[hat.event.common.RegisterEvent] # NOQA
 93                       ) -> Collection[hat.event.common.Event] | None:
 94        register_id = next(self._next_req_ids)
 95        req = transport.RegisterReqMsg(register_id=register_id,
 96                                       register_events=register_events)
 97
 98        res = await self._send_req(req, register_id)
 99        if not isinstance(res, transport.RegisterResMsg):
100            raise Exception('invalid register response')
101
102        return res.events
103
104    async def query(self,
105                    params: hat.event.common.QueryParams
106                    ) -> hat.event.common.QueryResult:
107        query_id = next(self._next_req_ids)
108        req = transport.QueryReqMsg(query_id=query_id,
109                                    params=params)
110
111        res = await self._send_req(req, query_id)
112        if not isinstance(res, transport.QueryResMsg):
113            raise Exception('invalid query response')
114
115        return res.result
116
117    async def _send_req(self, req, req_id):
118        if not self.is_open:
119            raise ConnectionError()
120
121        future = self._loop.create_future()
122        self._futures[req_id] = future
123
124        try:
125            await self._conn.send(req)
126
127            if not self.is_open:
128                raise ConnectionError()
129
130            return await future
131
132        finally:
133            self._futures.pop(req_id)
134
135    async def _receive_loop(self):
136        try:
137            while True:
138                msg = await self._conn.receive()
139
140                self._ping_event.set()
141
142                if isinstance(msg, transport.StatusMsg):
143                    self._status = msg.status
144
145                    if self._status_cb:
146                        await aio.call(self._status_cb, self, msg.status)
147
148                elif isinstance(msg, transport.EventsMsg):
149                    if self._events_cb:
150                        await aio.call(self._events_cb, self, msg.events)
151
152                elif isinstance(msg, transport.RegisterResMsg):
153                    future = self._futures.get(msg.register_id)
154                    if future and not future.done():
155                        future.set_result(msg)
156
157                elif isinstance(msg, transport.QueryResMsg):
158                    future = self._futures.get(msg.query_id)
159                    if future and not future.done():
160                        future.set_result(msg)
161
162                elif isinstance(msg, transport.PingResMsg):
163                    pass
164
165                else:
166                    raise Exception('invalid message type')
167
168        except ConnectionError:
169            pass
170
171        except Exception as e:
172            mlog.error("receive loop error: %s", e, exc_info=e)
173
174        finally:
175            self.close()
176
177            for future in self._futures.values():
178                if not future.done():
179                    future.set_exception(ConnectionError())
180
181    async def _ping_loop(self, delay, timeout):
182        try:
183            while True:
184                self._ping_event.clear()
185
186                with contextlib.suppress(asyncio.TimeoutError):
187                    await aio.wait_for(self._ping_event.wait(), delay)
188                    continue
189
190                req_id = next(self._next_req_ids)
191                req = transport.PingReqMsg(req_id)
192                await self._conn.send(req)
193
194                with contextlib.suppress(asyncio.TimeoutError):
195                    await aio.wait_for(self._ping_event.wait(), timeout)
196                    continue
197
198                mlog.debug("ping timeout")
199                break
200
201        except ConnectionError:
202            pass
203
204        finally:
205            self.close()

Resource with lifetime control based on Group.

async_group: hat.aio.group.Group
83    @property
84    def async_group(self) -> aio.Group:
85        return self._conn.async_group

Group controlling resource's lifetime.

status: hat.event.common.common.Status
87    @property
88    def status(self) -> hat.event.common.Status:
89        return self._status
async def register( self, register_events: Collection[hat.event.common.common.RegisterEvent]) -> Collection[hat.event.common.common.Event] | None:
 91    async def register(self,
 92                       register_events: Collection[hat.event.common.RegisterEvent] # NOQA
 93                       ) -> Collection[hat.event.common.Event] | None:
 94        register_id = next(self._next_req_ids)
 95        req = transport.RegisterReqMsg(register_id=register_id,
 96                                       register_events=register_events)
 97
 98        res = await self._send_req(req, register_id)
 99        if not isinstance(res, transport.RegisterResMsg):
100            raise Exception('invalid register response')
101
102        return res.events

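Register events.

Sends a register request containing `register_events` and waits for the matching register response. Returns the events from that response, which can be None. Raises ConnectionError if the connection is closed before the response is received.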

async def query( self, params: hat.event.common.common.QueryLatestParams | hat.event.common.common.QueryTimeseriesParams | hat.event.common.common.QueryServerParams) -> hat.event.common.common.QueryResult:
104    async def query(self,
105                    params: hat.event.common.QueryParams
106                    ) -> hat.event.common.QueryResult:
107        query_id = next(self._next_req_ids)
108        req = transport.QueryReqMsg(query_id=query_id,
109                                    params=params)
110
111        res = await self._send_req(req, query_id)
112        if not isinstance(res, transport.QueryResMsg):
113            raise Exception('invalid query response')
114
115        return res.result