Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of the sbk project
2# https://github.com/mbarkhau/sbk
3#
4# Copyright (c) 2019-2021 Manuel Barkhau (mbarkhau@gmail.com) - MIT License
5# SPDX-License-Identifier: MIT
7"""Reed Solomon style (home-grown) Forward Error Correction code.
9This may actually be a legitimate Reed Solomon encoding, it's certainly
10based on the same ideas, I'm just not sure if it qualifies. If you
11can contribute to making this conform to an appropriate standard, please
12open an issue or merge request on gitlab.
13"""
15import os
16import sys
17import math
18import base64
19import random
20import typing as typ
21import itertools
22import collections
24from . import gf
25from . import gf_poly
# Message: Raw data to be encoded, without ECC data
Message = bytes
# Block: Fully encoded Message including ECC data
Block = bytes
# Packet: Individual byte in a block
Packet = int
# Packets: Sequence of packets without erasures
Packets = typ.List[Packet]
# MaybePackets: Packets where an erasure is signified by None.
# The position in the sequence implies the x-coordinate.
MaybePackets = typ.List[typ.Optional[Packet]]
def _nCr(n: int, r: int) -> int:
    """Return the binomial coefficient "n choose r".

    The count is computed with exact integer arithmetic. Dividing
    factorials with ``/`` goes through floats, which loses precision for
    large results and raises OverflowError as soon as an intermediate
    quotient (e.g. ``255! / 2!``) exceeds the float range.

    Returns 0 when r > n.

    Raises:
        ValueError: if n or r is negative.
    """
    return math.comb(n, r)
class ECCDecodeError(ValueError):
    """Raised when a message cannot be recovered from the given packets."""
def _interpolate(points: gf_poly.Points, at_x: gf.GF256) -> gf.GF256:
    """Sum the interpolation terms of `points` evaluated at `at_x`.

    The terms come from gf_poly._interpolation_terms_256. If the
    environment variable SBK_VERIFY_ECC_RS_INTERPOLATION_TERMS is set to
    '1', they are additionally compared against the terms produced by
    gf_poly._interpolation_terms.
    """
    fast_terms = list(gf_poly._interpolation_terms_256(points, at_x=at_x))

    verify_flag = os.getenv('SBK_VERIFY_ECC_RS_INTERPOLATION_TERMS', "0")
    if verify_flag == '1':
        slow_terms = list(gf_poly._interpolation_terms(points, at_x=at_x))
        assert fast_terms == slow_terms

    # The first term seeds the accumulator, the rest are added onto it.
    remaining = iter(fast_terms)
    total = next(remaining)
    for term in remaining:
        total += term
    return total
def _encode(msg: Message, ecc_len: int) -> Block:
    """Return `msg` followed by `ecc_len` interpolated ECC bytes.

    Every byte of the message is used as a point (x=index, y=byte value)
    and the ECC bytes are the interpolated y values at the x-coordinates
    that directly follow the message.

    Messages shorter than two bytes are padded first: an empty message
    becomes two zero bytes and a single byte is duplicated.

    Raises:
        AssertionError: if any resulting value is outside of 0..255.
    """
    if len(msg) == 0:
        msg = b"\x00\x00"
    elif len(msg) == 1:
        # We need at least two points (and hence bytes) to do interpolation
        msg = msg + msg

    field = gf.FieldGF256()
    msg_len = len(msg)

    data_points = tuple(gf_poly.Point(field[idx], field[byte]) for idx, byte in enumerate(msg))

    # Resolve all x-coordinates before interpolating any of them.
    ecc_xs = [field[x_val] for x_val in range(msg_len, msg_len + ecc_len)]
    ecc_points = [gf_poly.Point(x=ecc_x, y=_interpolate(data_points, at_x=ecc_x)) for ecc_x in ecc_xs]

    block_vals = [point.y.val for point in data_points]
    block_vals.extend(point.y.val for point in ecc_points)

    if any(val < 0 or val > 255 for val in block_vals):
        raise AssertionError()

    return bytes(block_vals)
def encode(msg: Message, ecc_len: typ.Optional[int] = None, verify: bool = True) -> Block:
    """Encode message to a Block with RS Code as ECC data.

    Args:
        msg: The raw message bytes.
        ecc_len: Number of ECC bytes to append. If None, the block is
            padded to four bytes for every started pair of message bytes.
        verify: If true, the block is decoded again and compared to `msg`.

    Returns:
        The block, which starts with `msg`. With zero ECC bytes `msg`
        itself is returned.

    Raises:
        AssertionError: if the ECC length is negative, the block does not
            start with `msg`, or verification fails.
    """
    if ecc_len is not None:
        num_ecc = ecc_len
    else:
        padded_len = 4 * math.ceil(len(msg) / 2)
        num_ecc = padded_len - len(msg)

    if num_ecc < 0:
        raise AssertionError(str(num_ecc))
    if num_ecc == 0:
        return msg

    block = _encode(msg, ecc_len=num_ecc)

    has_msg_prefix = block.startswith(msg)
    if not has_msg_prefix:
        raise AssertionError()

    if verify:
        round_trip = decode(block, len(msg))
        if round_trip != msg:
            raise AssertionError()

    return block
# Indexes: positions (into a sequence of points) selected for one sample
Indexes = typ.Tuple[int, ...]
def _iter_indexes(msg_len: int, num_points: int) -> typ.Iterable[Indexes]:
    """Yield combinations of `msg_len` indexes out of ``range(num_points)``.

    Every yielded tuple holds `msg_len` distinct indexes in ascending
    order and no combination is yielded more than once.

    - With fewer than 1000 possible combinations, all of them are yielded
      in random order.
    - Otherwise combinations are sampled at random until a third of all
      possible combinations has been yielded.

    Raises:
        AssertionError: if num_points < msg_len.
    """
    if num_points < msg_len:
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        raise AssertionError(f"{num_points=} < {msg_len=}")

    all_indexes = tuple(range(num_points))
    if msg_len == num_points:
        # There is exactly one combination. Return here, otherwise the
        # exhaustive search below would yield it a second time.
        yield all_indexes
        return

    # Exact integer count. A float based count (factorials divided with
    # `/`) overflows, e.g. for 255 points and msg_len=2.
    num_combos = math.comb(num_points, msg_len)
    if num_combos < 1000:
        # few enough for exhaustive search
        all_combos = list(itertools.combinations(all_indexes, msg_len))
        random.shuffle(all_combos)
        yield from all_combos
    else:
        seen_combos: typ.Set[Indexes] = set()
        target = num_combos // 3
        while len(seen_combos) < target:
            # Sorting canonicalises the sample, so the same subset drawn in
            # a different order is recognised as a duplicate.
            sample_combo = tuple(sorted(random.sample(all_indexes, msg_len)))
            if sample_combo not in seen_combos:
                seen_combos.add(sample_combo)
                yield sample_combo
def decode_packets(packets: MaybePackets, msg_len: int) -> Message:
    """Recover a message of `msg_len` bytes from `packets`.

    Packets that are None are treated as erasures and skipped. The
    message is interpolated from many samples of `msg_len` points and the
    candidates are counted:

    - After every 10th sample the leading candidate is returned if it is
      the only one, or if it has more than 10x the count of the runner-up.
    - Once all samples are used up, the leading candidate is returned if
      it is the only one, or if it has more than 2x the count of the
      runner-up.

    Raises:
        ECCDecodeError: if fewer than `msg_len` packets are available or
            no candidate is sufficiently dominant.
    """
    field = gf.FieldGF256()
    points = tuple(
        gf_poly.Point(field[x], field[y])
        for x, y in enumerate(packets)
        if y is not None
    )

    if len(points) < msg_len:
        raise ECCDecodeError("Not enough data to recover message.")

    msg_x_coords = tuple(field[x] for x in range(msg_len))
    candidates: typ.Counter[bytes] = collections.Counter()

    samples = _iter_indexes(msg_len, len(points))
    for num_samples, point_indexes in enumerate(samples, start=1):
        sample_points = tuple(points[idx] for idx in point_indexes)
        msg_candidate = bytes(_interpolate(sample_points, at_x=x).val for x in msg_x_coords)
        candidates[msg_candidate] += 1

        if num_samples % 10 != 0:
            continue

        ranking = candidates.most_common(2)
        if len(ranking) == 1:
            return ranking[0][0]

        (leader, leader_n), (_runner_up, runner_up_n) = ranking
        if leader_n > runner_up_n * 10:
            return leader

    ranking = candidates.most_common(2)
    if len(ranking) == 1:
        return ranking[0][0]

    # last ditch check
    (leader, leader_n), (_runner_up, runner_up_n) = ranking
    if leader_n > runner_up_n * 2:
        return leader

    raise ECCDecodeError("Message too corrupt to recover.")
def decode(block: Block, msg_len: int) -> Message:
    """Decode `block` to the message of `msg_len` bytes it starts with.

    A block without ECC data (same length as the message) is returned
    unchanged, otherwise the bytes of the block are decoded as packets.

    Raises:
        AssertionError: if the block is shorter than `msg_len`.
    """
    ecc_len = len(block) - msg_len
    if ecc_len < 0:
        raise AssertionError(f"Invalid {ecc_len=}")
    if ecc_len == 0:
        return block
    return decode_packets(list(block), msg_len)
def _cli_encode(msg: str) -> str:
    """Encode `msg` to a lower case base16 string of the ECC block.

    The message is encoded as utf-8 and gets as many ECC bytes as it has
    message bytes.
    """
    raw_msg = msg.encode("utf-8")
    encoded = encode(raw_msg, ecc_len=len(raw_msg))
    return base64.b16encode(encoded).decode('ascii').lower()
def _cli_decode(block_b16_str: str) -> str:
    """Decode a base16 block string (as produced by _cli_encode) to text.

    Trailing newlines are stripped. Every pair of characters is one
    packet; pairs that are not valid base16 are treated as erasures. The
    message length is taken to be a quarter of the number of base16
    characters, i.e. half of the packets.
    """
    hex_data = block_b16_str.rstrip("\n").upper().encode("ascii")
    num_pairs = len(hex_data) // 2

    packets: MaybePackets = []
    for offset in range(0, len(hex_data), 2):
        pair = hex_data[offset : offset + 2]
        try:
            decoded = base64.b16decode(pair)
        except ValueError:
            # Not valid base16 -> erasure
            packets.append(None)
        else:
            packets.append(decoded[0])

    assert len(packets) == num_pairs
    recovered = decode_packets(packets, msg_len=len(hex_data) // 4)
    return recovered.decode("utf-8")
202CLI_HELP = """CLI to demo recovery using ecc.
204Example usage:
206 $ python -m sbk.ecc_rs --test
207 ...
208 $ py-spy record -r 1000 --output profile.svg -- python -m sbk.ecc_rs --test
209 ...
210 $ firefox profile.svg
212 $ python -m sbk.ecc_rs --profile
213 ...
214 $ echo "Hello, 世界!" | python -m sbk.ecc_rs --encode
215 48656c6c6f2c20e4b896e7958c210a51ee32d3ac1bee26daac14d3b95428
216 $ echo "48656c6c6f2c20e4b896e7958c210a51ee32d3ac1bee26daac14d3b95428" | python -m sbk.ecc_rs --decode
217 Hello, 世界!
218 $ echo "48656c6c6f2c20e4b896e7958c210a " | python -m sbk.ecc_rs --decode
219 Hello, 世界!
220 $ echo " 51ee32d3ac1bee26daac14d3b95428" | python -m sbk.ecc_rs --decode
221 Hello, 世界!
222 $ echo "48656c6c6f2c20 26daac14d3b95428" | python -m sbk.ecc_rs --decode
223 Hello, 世界!
224"""
def main(args: typ.Sequence[str] = sys.argv[1:], stdin: typ.TextIO = sys.stdin) -> int:
    # pylint: disable=dangerous-default-value; we don't modify it, I promise.
    #
    # CLI entry point. Dispatches on the flags found in `args` and returns
    # the exit code: 0 on success, 1 if no known flag was given.
    wants_help = "-h" in args or "--help" in args or not args
    if wants_help:
        print(main.__doc__)
        return 0

    sample_text = "Hello, 世界!!!!iasdf1234567890!!!!"

    if "--test" in args:
        # Round trip of a fixed sample through encode and decode.
        round_trip = _cli_decode(_cli_encode(sample_text))
        assert round_trip == sample_text
        print("ok")
        return 0

    if "--profile" in args:
        import io
        import pstats
        import cProfile

        # init lookup tables
        _cli_encode(sample_text)

        profiler = cProfile.Profile()
        profiler.enable()
        _cli_encode(sample_text)
        profiler.disable()

        report = io.StringIO()
        pstats.Stats(profiler, stream=report).sort_stats('cumulative').print_stats()
        print(report.getvalue())
        return 0

    # The remaining modes operate on the text read from stdin.
    input_data = stdin.read()

    if "-e" in args or "--encode" in args:
        sys.stdout.write(_cli_encode(input_data))
        return 0

    if "-d" in args or "--decode" in args:
        sys.stdout.write(_cli_decode(input_data))
        return 0

    sys.stderr.write("Invalid arguments\n")
    sys.stderr.write(CLI_HELP)
    return 1
# The CLI help text doubles as the docstring that `main` prints for -h/--help.
main.__doc__ = CLI_HELP


if __name__ == '__main__':
    # Run the CLI and use its return value as the process exit code.
    sys.exit(main())