Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of the sbk project 

2# https://github.com/mbarkhau/sbk 

3# 

4# Copyright (c) 2019-2021 Manuel Barkhau (mbarkhau@gmail.com) - MIT License 

5# SPDX-License-Identifier: MIT 

6 

"""Reed Solomon style (home-grown) Forward Error Correction code.

Each message byte is used as the y-value of a point over GF(256) whose
x-coordinate is the byte's position in the message. The ECC bytes are the
values interpolated from those points at the x-coordinates directly
following the message. Decoding interpolates the message from subsets of
the surviving points and picks the result that occurs most often.

This may actually be a legitimate Reed Solomon encoding; it's certainly
based on the same ideas, I'm just not sure if it qualifies. If you
can contribute to making this conform to an appropriate standard, please
open an issue or merge request on GitHub (https://github.com/mbarkhau/sbk).
"""

14 

15import os 

16import sys 

17import math 

18import base64 

19import random 

20import typing as typ 

21import itertools 

22import collections 

23 

24from . import gf 

25from . import gf_poly 

26 

# Raw data to be encoded (without any ECC data).
Message = bytes
# A fully encoded Message: the message bytes followed by the ECC bytes.
Block = bytes

# A single byte of a Block.
Packet = int
Packets = typ.List[Packet]
# Packets where erased/unknown entries are None; the index of each entry
# is its x-coordinate.
MaybePackets = typ.List[typ.Optional[Packet]]

37 

38 

def _nCr(n: int, r: int) -> int:
    """Return the binomial coefficient "n choose r".

    Uses exact integer arithmetic via ``math.comb``. The previous
    ``n! / r! / (n - r)!`` float division lost precision for moderately
    large inputs and raised ``OverflowError`` once the quotient exceeded
    the float range (e.g. ``n=300, r=150``). An ``int`` result still works
    with every comparison and ``//`` that callers apply to it.
    """
    return math.comb(n, r)

42 

43 

class ECCDecodeError(ValueError):
    """Raised when a message cannot be recovered from the given packets."""

46 

47 

def _interpolate(points: gf_poly.Points, at_x: gf.GF256) -> gf.GF256:
    """Interpolate the value at ``at_x`` from ``points``.

    The result is the sum of the terms produced by
    ``gf_poly._interpolation_terms_256``. If the environment variable
    ``SBK_VERIFY_ECC_RS_INTERPOLATION_TERMS`` is ``"1"``, those terms are
    also cross-checked (via ``assert``) against
    ``gf_poly._interpolation_terms``.

    Raises StopIteration if no terms are produced.
    """
    fast = list(gf_poly._interpolation_terms_256(points, at_x=at_x))

    if os.getenv('SBK_VERIFY_ECC_RS_INTERPOLATION_TERMS', "0") == '1':
        # Debug aid: the fast implementation must agree with the slow one.
        slow = list(gf_poly._interpolation_terms(points, at_x=at_x))
        assert fast == slow

    remaining = iter(fast)
    total = next(remaining)
    for term in remaining:
        total += term
    return total

60 

61 

def _encode(msg: Message, ecc_len: int) -> Block:
    """Return the (possibly padded) message followed by ``ecc_len`` ECC bytes.

    Each message byte is the y-value of a point whose x-coordinate is the
    byte's index. Each ECC byte is the value interpolated from those points
    at the x-coordinates directly following the message.

    Interpolation needs at least two points, so short messages are padded
    first: an empty message becomes two zero bytes, and a one-byte message
    is repeated once. The returned block starts with the padded message.
    """
    if not msg:
        msg = b"\x00\x00"
    elif len(msg) == 1:
        # We need at least two points (and hence bytes) to do interpolation
        msg = msg * 2

    field = gf.FieldGF256()
    msg_len = len(msg)

    data_points = tuple(gf_poly.Point(field[x], field[y]) for x, y in enumerate(msg))
    ecc_xs = tuple(field[x] for x in range(msg_len, msg_len + ecc_len))
    ecc_points = tuple(
        gf_poly.Point(x=ecc_x, y=_interpolate(data_points, at_x=ecc_x)) for ecc_x in ecc_xs
    )

    y_vals = [point.y.val for point in itertools.chain(data_points, ecc_points)]
    # Every value must fit into a single byte.
    if any(y < 0 or y > 255 for y in y_vals):
        raise AssertionError()

    return bytes(y_vals)

80 

81 

def encode(msg: Message, ecc_len: typ.Optional[int] = None, verify: bool = True) -> Block:
    """Encode ``msg`` into a Block that starts with ``msg`` and ends with ECC data.

    If ``ecc_len`` is None, enough ECC bytes are added to make the block the
    smallest multiple of four bytes that is at least twice ``len(msg)``.
    If no ECC bytes are needed, ``msg`` is returned unchanged.

    With ``verify`` (the default), the block is decoded again and must
    reproduce ``msg``. Internal consistency failures raise AssertionError.
    """
    if ecc_len is not None:
        num_ecc = ecc_len
    else:
        # ceil(len(msg) / 2) * 4, in integer arithmetic
        num_ecc = ((len(msg) + 1) // 2) * 4 - len(msg)

    if num_ecc < 0:
        raise AssertionError(str(num_ecc))
    if num_ecc == 0:
        return msg

    block = _encode(msg, ecc_len=num_ecc)
    if not block.startswith(msg):
        raise AssertionError()
    if verify and decode(block, len(msg)) != msg:
        raise AssertionError()

    return block

104 

105 

# A subset of point indexes, in ascending order.
Indexes = typ.Tuple[int, ...]


def _iter_indexes(msg_len: int, num_points: int) -> typ.Iterable[Indexes]:
    """Yield subsets of ``msg_len`` point indexes out of ``range(num_points)``.

    Each yielded tuple holds distinct indexes in ascending order, and no
    subset is yielded more than once.

    - If ``msg_len == num_points``, the single subset of all indexes is
      yielded (exactly once) and iteration stops.
    - If there are fewer than 1000 subsets, all of them are yielded in
      random order (exhaustive search).
    - Otherwise random subsets are yielded until a third of all possible
      subsets have been produced. Callers are expected to stop early.
    """
    assert num_points >= msg_len

    all_indexes = tuple(range(num_points))
    if msg_len == num_points:
        yield all_indexes
        # Without this return, the same subset would be yielded a second
        # time by the exhaustive branch below (nCr(n, n) == 1).
        return

    num_combos = math.comb(num_points, msg_len)
    if num_combos < 1000:
        # few enough for exhaustive search
        all_combos = list(itertools.combinations(all_indexes, msg_len))
        random.shuffle(all_combos)
        yield from all_combos
    else:
        seen: typ.Set[Indexes] = set()
        while len(seen) < num_combos // 3:
            # Sort, so that different orderings of the same subset (which
            # interpolate to the same result) count as one combination.
            combo = tuple(sorted(random.sample(all_indexes, msg_len)))
            if combo not in seen:
                seen.add(combo)
                yield combo

131 

132 

def decode_packets(packets: MaybePackets, msg_len: int) -> Message:
    """Recover a message of ``msg_len`` bytes from ``packets``.

    ``None`` entries are erasures. Every other entry is a point whose
    x-coordinate is its index. The message is interpolated from subsets of
    ``msg_len`` points, and each distinct result is counted as a vote.
    Every 10 samples the vote is checked for a clear winner. A final, looser
    check is applied once all samples are exhausted.

    Raises ECCDecodeError if fewer than ``msg_len`` packets are present, or
    if no candidate wins the vote clearly enough.
    """
    field = gf.FieldGF256()
    known_points = tuple(
        gf_poly.Point(field[x], field[y]) for x, y in enumerate(packets) if y is not None
    )

    if len(known_points) < msg_len:
        raise ECCDecodeError("Not enough data to recover message.")

    msg_xs = tuple(field[x] for x in range(msg_len))
    votes: typ.Counter[bytes] = collections.Counter()
    combos = _iter_indexes(msg_len, len(known_points))
    for num_samples, combo in enumerate(combos, start=1):
        subset = tuple(known_points[i] for i in combo)
        candidate = bytes(_interpolate(subset, at_x=x).val for x in msg_xs)
        votes[candidate] += 1

        if num_samples % 10 != 0:
            continue

        if len(votes) == 1:
            ((winner, _winner_n),) = votes.most_common(1)
            return winner

        (best, best_n), (_runner_up, runner_up_n) = votes.most_common(2)
        if best_n > runner_up_n * 10:
            return best

    if len(votes) == 1:
        ((winner, _winner_n),) = votes.most_common(1)
        return winner

    # last ditch check
    (best, best_n), (_runner_up, runner_up_n) = votes.most_common(2)
    if best_n > runner_up_n * 2:
        return best

    raise ECCDecodeError("Message too corrupt to recover.")

167 

168 

def decode(block: Block, msg_len: int) -> Message:
    """Return the first ``msg_len`` bytes of message recovered from ``block``.

    A block without ECC data is returned as is. A block shorter than
    ``msg_len`` raises AssertionError.
    """
    ecc_len = len(block) - msg_len
    if ecc_len < 0:
        raise AssertionError(f"Invalid {ecc_len=}")
    if ecc_len == 0:
        return block
    return decode_packets(list(block), msg_len)

177 

178 

def _cli_encode(msg: str) -> str:
    """Encode ``msg`` (as UTF-8) and return the block as lowercase hex.

    The ECC data is as long as the encoded message itself.
    """
    data = msg.encode("utf-8")
    encoded = encode(data, ecc_len=len(data))
    return base64.b16encode(encoded).decode('ascii').lower()

184 

185 

def _cli_decode(block_b16_str: str) -> str:
    """Decode a hex block (as produced by ``_cli_encode``) back to text.

    Every two characters form one packet. Pairs that are not valid hex
    (e.g. blanks) are treated as erased packets. As with ``_cli_encode``,
    the message is assumed to be half of the block. The recovered bytes
    are decoded as UTF-8.
    """
    # Strip "\r" too, so CRLF input doesn't leave a stray trailing
    # character. Blanks are deliberately kept: they mark erasures.
    # errors="replace" turns each non-ASCII char into a single b"?", which
    # is not hex and so becomes an erasure without shifting positions.
    block_b16 = block_b16_str.rstrip("\r\n").upper().encode("ascii", errors="replace")

    # An odd trailing character cannot form a complete packet; ignore it
    # instead of failing on a length mismatch.
    num_packets = len(block_b16) // 2
    packets: MaybePackets = []
    for i in range(0, num_packets * 2, 2):
        packet_b16 = block_b16[i : i + 2]
        try:
            packet = base64.b16decode(packet_b16)[0]
            packets.append(packet)
        except ValueError:
            # binascii.Error (a ValueError subclass) for non-hex pairs
            packets.append(None)

    msg_data = decode_packets(packets, msg_len=num_packets // 2)
    return msg_data.decode("utf-8")

200 

201 

# In the --decode examples each pair of blank characters is one erased
# byte. A message is recoverable while at least half the packets survive
# (the CLI uses as many ECC bytes as message bytes).
CLI_HELP = """CLI to demo recovery using ecc.

Example usage:

    $ python -m sbk.ecc_rs --test
    ...
    $ py-spy record -r 1000 --output profile.svg -- python -m sbk.ecc_rs --test
    ...
    $ firefox profile.svg

    $ python -m sbk.ecc_rs --profile
    ...
    $ echo "Hello, 世界!" | python -m sbk.ecc_rs --encode
    48656c6c6f2c20e4b896e7958c210a51ee32d3ac1bee26daac14d3b95428
    $ echo "48656c6c6f2c20e4b896e7958c210a51ee32d3ac1bee26daac14d3b95428" | python -m sbk.ecc_rs --decode
    Hello, 世界!
    $ echo "48656c6c6f2c20e4b896e7958c210a                              " | python -m sbk.ecc_rs --decode
    Hello, 世界!
    $ echo "                              51ee32d3ac1bee26daac14d3b95428" | python -m sbk.ecc_rs --decode
    Hello, 世界!
    $ echo "48656c6c6f2c20                              26daac14d3b95428" | python -m sbk.ecc_rs --decode
    Hello, 世界!
"""

225 

226 

def main(
    args: typ.Optional[typ.Sequence[str]] = None,
    stdin: typ.Optional[typ.TextIO] = None,
) -> int:
    """Run the ecc_rs demo CLI and return the process exit code.

    :param args: command line arguments; defaults to ``sys.argv[1:]``,
        read at call time (a default expression would be frozen at import).
    :param stdin: input stream for --encode/--decode; defaults to the
        ``sys.stdin`` that is current at call time.
    :returns: 0 on success, 1 for unrecognized arguments.
    """
    if args is None:
        args = sys.argv[1:]
    if stdin is None:
        stdin = sys.stdin

    if "-h" in args or "--help" in args or not args:
        # main.__doc__ is set to CLI_HELP at module level.
        print(main.__doc__)
        return 0

    if "--test" in args:
        input_data = "Hello, 世界!!!!iasdf1234567890!!!!"
        block = _cli_encode(input_data)
        msg = _cli_decode(block)
        assert msg == input_data
        print("ok")
        return 0

    if "--profile" in args:
        import io
        import pstats
        import cProfile

        # init lookup tables
        input_data = "Hello, 世界!!!!iasdf1234567890!!!!"
        _cli_encode(input_data)

        pr = cProfile.Profile()
        pr.enable()
        _cli_encode(input_data)
        pr.disable()
        s = io.StringIO()
        ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
        ps.print_stats()
        print(s.getvalue())
        return 0

    input_data = stdin.read()

    if "-e" in args or "--encode" in args:
        block_b16 = _cli_encode(input_data)
        sys.stdout.write(block_b16)
        return 0
    elif "-d" in args or "--decode" in args:
        msg = _cli_decode(input_data)
        sys.stdout.write(msg)
        return 0
    else:
        sys.stderr.write("Invalid arguments\n")
        sys.stderr.write(CLI_HELP)
        return 1

274 

275 

# `--help` prints main.__doc__, so expose the CLI help text there.
main.__doc__ = CLI_HELP


if __name__ == '__main__':
    raise SystemExit(main())