438 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			438 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import bisect
 | |
| import re
 | |
| import unicodedata
 | |
| from typing import Optional, Union
 | |
| 
 | |
| from . import idnadata
 | |
| from .intranges import intranges_contain
 | |
| 
 | |
# Canonical combining class value shared by virama characters; valid_contextj()
# accepts a joiner (U+200C / U+200D) directly after a character of this class.
_virama_combining_class = 9
# Prefix that marks an ASCII-compatible (punycode) A-label.
_alabel_prefix = b"xn--"
# Label separators accepted in non-strict mode: FULL STOP (U+002E),
# IDEOGRAPHIC FULL STOP (U+3002), FULLWIDTH FULL STOP (U+FF0E) and
# HALFWIDTH IDEOGRAPHIC FULL STOP (U+FF61).
_unicode_dots_re = re.compile("[\u002e\u3002\uff0e\uff61]")
 | |
| 
 | |
| 
 | |
class IDNAError(UnicodeError):
    """Root of this module's exception hierarchy for IDNA encode/decode failures.

    It derives from the built-in UnicodeError (itself a ValueError), so code
    that already handles those errors will also see IDNA problems.
    """
 | |
| 
 | |
| 
 | |
class IDNABidiError(IDNAError):
    """Raised when a label fails the bidirectional-text (Bidi) checks.

    check_bidi() raises it for a codepoint whose directionality is unknown
    and for any violation of the numbered Bidi rules it enforces.
    """
 | |
| 
 | |
| 
 | |
class InvalidCodepoint(IDNAError):
    """Raised for a codepoint that may not appear in a label at all.

    check_label() raises it for codepoints outside the PVALID, CONTEXTJ and
    CONTEXTO classes; uts46_remap() raises it for codepoints that cannot be
    kept, mapped or ignored under the selected UTS46 options.
    """
 | |
| 
 | |
| 
 | |
class InvalidCodepointContext(IDNAError):
    """Raised when a context-dependent (CONTEXTJ or CONTEXTO) codepoint
    appears in a position its context rule does not allow."""
 | |
| 
 | |
| 
 | |
| def _combining_class(cp: int) -> int:
 | |
|     v = unicodedata.combining(chr(cp))
 | |
|     if v == 0:
 | |
|         if not unicodedata.name(chr(cp)):
 | |
|             raise ValueError("Unknown character in unicodedata")
 | |
|     return v
 | |
| 
 | |
| 
 | |
def _is_script(cp: str, script: str) -> bool:
    """Return True if the single character *cp* lies in the ranges that
    ``idnadata.scripts`` records for *script* (e.g. ``"Greek"``)."""
    code = ord(cp)
    script_ranges = idnadata.scripts[script]
    return intranges_contain(code, script_ranges)
 | |
| 
 | |
| 
 | |
| def _punycode(s: str) -> bytes:
 | |
|     return s.encode("punycode")
 | |
| 
 | |
| 
 | |
| def _unot(s: int) -> str:
 | |
|     return "U+{:04X}".format(s)
 | |
| 
 | |
| 
 | |
def valid_label_length(label: Union[bytes, str]) -> bool:
    """Return True if *label* is at most 63 units long.

    Units are characters for ``str`` and bytes for ``bytes``.
    """
    return len(label) <= 63
 | |
| 
 | |
| 
 | |
def valid_string_length(label: Union[bytes, str], trailing_dot: bool) -> bool:
    """Return True if a whole encoded domain fits the length limit.

    The limit is 253 units, or 254 when the domain keeps its trailing dot.
    """
    limit = 254 if trailing_dot else 253
    return len(label) <= limit
 | |
| 
 | |
| 
 | |
def check_bidi(label: str, check_ltr: bool = False) -> bool:
    """Validate *label* against the six numbered Bidi rules (RFC 5893, section 2).

    Every codepoint must have a known Bidi class. Beyond that, the rules are
    only enforced when the label contains a right-to-left character (Bidi
    class R, AL or AN) or when *check_ltr* is true.

    Args:
        label: the label to check (an empty label with *check_ltr* set
            raises IndexError).
        check_ltr: enforce the rules even for purely left-to-right labels.

    Returns:
        True when the label passes (or the rules do not apply).

    Raises:
        IDNABidiError: a codepoint has no Bidi class in this interpreter's
            Unicode database, or a Bidi rule is violated.
    """
    rtl_classes = ("R", "AL", "AN")
    contains_rtl = False
    for position, char in enumerate(label, 1):
        bidi_class = unicodedata.bidirectional(char)
        if not bidi_class:
            # An empty class most likely means the character is newer than the
            # Unicode database bundled with this interpreter.
            raise IDNABidiError("Unknown directionality in label {} at position {}".format(repr(label), position))
        contains_rtl = contains_rtl or bidi_class in rtl_classes
    if not (contains_rtl or check_ltr):
        return True

    # Rule 1: the first character fixes the label's direction.
    first_class = unicodedata.bidirectional(label[0])
    if first_class not in ("L", "R", "AL"):
        raise IDNABidiError("First codepoint in label {} must be directionality L, R or AL".format(repr(label)))
    rtl = first_class != "L"

    if rtl:
        # Rule 2 (permitted classes) and rule 3 (permitted final classes).
        permitted = frozenset(("R", "AL", "AN", "EN", "ES", "CS", "ET", "ON", "BN", "NSM"))
        final_ok = frozenset(("R", "AL", "EN", "AN"))
    else:
        # Rule 5 (permitted classes) and rule 6 (permitted final classes).
        permitted = frozenset(("L", "EN", "ES", "CS", "ET", "ON", "BN", "NSM"))
        final_ok = frozenset(("L", "EN"))

    ends_well = False
    numeral_class: Optional[str] = None
    for position, char in enumerate(label, 1):
        bidi_class = unicodedata.bidirectional(char)
        if bidi_class not in permitted:
            if rtl:
                raise IDNABidiError("Invalid direction for codepoint at position {} in a right-to-left label".format(position))
            raise IDNABidiError("Invalid direction for codepoint at position {} in a left-to-right label".format(position))
        # Trailing NSM marks leave the ending verdict unchanged.
        if bidi_class in final_ok:
            ends_well = True
        elif bidi_class != "NSM":
            ends_well = False
        # Rule 4: an RTL label may contain EN digits or AN digits, never both.
        if rtl and bidi_class in ("AN", "EN"):
            if numeral_class is None:
                numeral_class = bidi_class
            elif numeral_class != bidi_class:
                raise IDNABidiError("Can not mix numeral types in a right-to-left label")

    if not ends_well:
        raise IDNABidiError("Label ends with illegal codepoint directionality")
    return True
 | |
| 
 | |
| 
 | |
def check_initial_combiner(label: str) -> bool:
    """Reject a label whose first character is a combining mark (category M*).

    Returns True when the label is acceptable; raises IDNAError otherwise.
    An empty label raises IndexError.
    """
    first_category = unicodedata.category(label[0])
    if first_category.startswith("M"):
        raise IDNAError("Label begins with an illegal combining character")
    return True
 | |
| 
 | |
| 
 | |
def check_hyphen_ok(label: str) -> bool:
    """Enforce the hyphen restrictions on a label.

    A label may not have hyphens in both its 3rd and 4th positions, and may
    not start or end with a hyphen. Returns True when it passes; raises
    IDNAError otherwise (an empty label raises IndexError).
    """
    if label[2:4] == "--":
        raise IDNAError("Label has disallowed hyphens in 3rd and 4th position")
    if "-" in (label[0], label[-1]):
        raise IDNAError("Label must not start or end with a hyphen")
    return True
 | |
| 
 | |
| 
 | |
def check_nfc(label: str) -> None:
    """Raise IDNAError unless *label* is already in Unicode Normalization Form C."""
    normalized = unicodedata.normalize("NFC", label)
    if normalized == label:
        return
    raise IDNAError("Label must be in Normalization Form C")
 | |
| 
 | |
| 
 | |
def valid_contextj(label: str, pos: int) -> bool:
    """Evaluate the CONTEXTJ rule for the joiner at ``label[pos]``.

    * U+200D (ZERO WIDTH JOINER) is allowed only directly after a character
      whose combining class is virama (``_virama_combining_class``).
    * U+200C (ZERO WIDTH NON-JOINER) is allowed after a virama as well, or
      when, skipping characters of joining type T, the nearest character
      before it has joining type L or D and the nearest character after it
      has joining type R or D.

    Any other codepoint at *pos* yields False.

    Raises:
        ValueError: propagated from _combining_class() when the preceding
            character is unknown to unicodedata.
    """
    cp_value = ord(label[pos])
    if cp_value not in (0x200C, 0x200D):
        return False

    # Both joiners are acceptable right after a virama.
    if pos > 0 and _combining_class(ord(label[pos - 1])) == _virama_combining_class:
        return True

    if cp_value == 0x200D:
        return False

    transparent = ord("T")

    def _nearest_joins(indices, accepted):
        # Walk away from the joiner, skipping transparent (T) characters; the
        # first other character decides. Running out of characters fails.
        for i in indices:
            joining_type = idnadata.joining_types.get(ord(label[i]))
            if joining_type != transparent:
                return joining_type in accepted
        return False

    return _nearest_joins(range(pos - 1, -1, -1), (ord("L"), ord("D"))) and _nearest_joins(
        range(pos + 1, len(label)), (ord("R"), ord("D"))
    )
 | |
| 
 | |
| 
 | |
def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
    """Evaluate the CONTEXTO rule for the codepoint at ``label[pos]``.

    * U+00B7 MIDDLE DOT: an "l" (U+006C) on both sides.
    * U+0375 GREEK LOWER NUMERAL SIGN: followed by a Greek-script character.
    * U+05F3 / U+05F4 HEBREW GERESH / GERSHAYIM: preceded by a
      Hebrew-script character.
    * U+30FB KATAKANA MIDDLE DOT: the label holds at least one Hiragana,
      Katakana or Han character other than U+30FB itself.
    * U+0660..U+0669 ARABIC-INDIC DIGITS: no U+06F0..U+06F9 anywhere in
      the label.
    * U+06F0..U+06F9 EXTENDED ARABIC-INDIC DIGITS: no U+0660..U+0669
      anywhere in the label.

    Any other codepoint yields False. *exception* is accepted but unused.
    """
    cp_value = ord(label[pos])
    last = len(label) - 1

    if cp_value == 0x00B7:
        return 0 < pos < last and label[pos - 1] == "l" and label[pos + 1] == "l"

    if cp_value == 0x0375:
        if len(label) > 1 and pos < last:
            return _is_script(label[pos + 1], "Greek")
        return False

    if cp_value in (0x05F3, 0x05F4):
        if pos > 0:
            return _is_script(label[pos - 1], "Hebrew")
        return False

    if cp_value == 0x30FB:
        return any(
            ch != "\u30fb" and (_is_script(ch, "Hiragana") or _is_script(ch, "Katakana") or _is_script(ch, "Han"))
            for ch in label
        )

    if 0x0660 <= cp_value <= 0x0669:
        return not any(0x06F0 <= ord(ch) <= 0x06F9 for ch in label)

    if 0x06F0 <= cp_value <= 0x06F9:
        return not any(0x0660 <= ord(ch) <= 0x0669 for ch in label)

    return False
 | |
| 
 | |
| 
 | |
def check_label(label: Union[str, bytes, bytearray]) -> None:
    """Validate a single U-label against the IDNA codepoint and context rules.

    bytes/bytearray input is first decoded as UTF-8. The label must be
    non-empty, in NFC, pass the hyphen and initial-combiner checks, consist
    only of PVALID codepoints or CONTEXTJ/CONTEXTO codepoints whose context
    rule holds, and finally pass check_bidi().

    Raises:
        IDNAError: empty label, or a character next to a joiner is unknown
            to unicodedata (also raised by the check_* helpers).
        InvalidCodepointContext: a CONTEXTJ or CONTEXTO codepoint fails its
            context rule.
        InvalidCodepoint: a codepoint outside every allowed class.
        IDNABidiError: raised by check_bidi().
    """
    if isinstance(label, (bytes, bytearray)):
        label = label.decode("utf-8")
    if len(label) == 0:
        raise IDNAError("Empty Label")

    check_nfc(label)
    check_hyphen_ok(label)
    check_initial_combiner(label)

    for pos, cp in enumerate(label):
        cp_value = ord(cp)
        if intranges_contain(cp_value, idnadata.codepoint_classes["PVALID"]):
            continue
        elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTJ"]):
            # Keep only the lookup inside the try: IDNAError derives from
            # UnicodeError -> ValueError, so raising InvalidCodepointContext
            # in here would be caught below and mislabelled as an
            # "unknown codepoint" error.
            try:
                joiner_ok = valid_contextj(label, pos)
            except ValueError as err:
                raise IDNAError(
                    "Unknown codepoint adjacent to joiner {} at position {} in {}".format(
                        _unot(cp_value), pos + 1, repr(label)
                    )
                ) from err
            if not joiner_ok:
                raise InvalidCodepointContext(
                    "Joiner {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
                )
        elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTO"]):
            if not valid_contexto(label, pos):
                raise InvalidCodepointContext(
                    "Codepoint {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
                )
        else:
            raise InvalidCodepoint(
                "Codepoint {} at position {} of {} not allowed".format(_unot(cp_value), pos + 1, repr(label))
            )

    check_bidi(label)
 | |
| 
 | |
| 
 | |
def alabel(label: str) -> bytes:
    """Convert one label to its ASCII (A-label) byte form.

    An all-ASCII label is validated by running it through ulabel(), which
    also checks any existing ``xn--`` payload. Its ASCII bytes are then
    returned with case preserved. Any other label is validated with
    check_label() and punycode-encoded behind the ``xn--`` prefix.

    Raises:
        IDNAError: the result exceeds 63 bytes, or a validation helper
            rejects the label (subclasses of IDNAError included).
    """
    try:
        ascii_form = label.encode("ascii")
    except UnicodeEncodeError:
        # Non-ASCII input continues to the punycode path below.
        pass
    else:
        ulabel(ascii_form)
        if not valid_label_length(ascii_form):
            raise IDNAError("Label too long")
        return ascii_form

    check_label(label)
    encoded = _alabel_prefix + _punycode(label)
    if not valid_label_length(encoded):
        raise IDNAError("Label too long")
    return encoded
 | |
| 
 | |
| 
 | |
def ulabel(label: Union[str, bytes, bytearray]) -> str:
    """Convert one label to its Unicode (U-label) form, validating it.

    * A non-ASCII ``str`` is validated with check_label() and returned as is.
    * Anything else is ASCII-lowercased. Without the ``xn--`` prefix, it is
      validated and returned as a ``str``. With the prefix, the remainder is
      punycode-decoded, validated and returned.

    Raises:
        IDNAError: empty or hyphen-terminated A-label payload, undecodable
            punycode, or any check_label() failure.
    """
    if isinstance(label, (bytes, bytearray)):
        raw = label
    else:
        try:
            raw = label.encode("ascii")
        except UnicodeEncodeError:
            check_label(label)
            return label

    raw = raw.lower()
    if not raw.startswith(_alabel_prefix):
        check_label(raw)
        return raw.decode("ascii")

    payload = raw[len(_alabel_prefix) :]
    if not payload:
        raise IDNAError("Malformed A-label, no Punycode eligible content found")
    if payload.decode("ascii")[-1] == "-":
        raise IDNAError("A-label must not end with a hyphen")

    try:
        decoded = payload.decode("punycode")
    except UnicodeError:
        raise IDNAError("Invalid A-label")
    check_label(decoded)
    return decoded
 | |
| 
 | |
| 
 | |
def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False) -> str:
    """Re-map the characters in the string according to UTS46 processing.

    Each codepoint's row in ``uts46data`` supplies a status (``row[1]``) and
    an optional replacement (``row[2]``). The character is:

    * kept for status "V", for "D" when not *transitional*, and for "3"
      without a replacement when *std3_rules* is false;
    * replaced for status "M", for "3" when *std3_rules* is false, and for
      "D" when *transitional* (only if a replacement exists);
    * dropped for status "I".

    Any other case raises InvalidCodepoint. The result is NFC-normalized.
    """
    from .uts46data import uts46data

    def _disallowed(pos, code_point):
        return InvalidCodepoint(
            "Codepoint {} not allowed at position {} in {}".format(_unot(code_point), pos + 1, repr(domain))
        )

    output = []
    for pos, char in enumerate(domain):
        code_point = ord(char)
        # The lookup assumes rows 0-255 are one per codepoint. Above that,
        # rows are sorted range starts, so bisect to the last row starting at
        # or below code_point ("Z" sorts after every status letter).
        index = code_point if code_point < 256 else bisect.bisect_left(uts46data, (code_point, "Z")) - 1
        try:
            uts46row = uts46data[index]
            status = uts46row[1]
        except IndexError:
            raise _disallowed(pos, code_point) from None
        replacement: Optional[str] = uts46row[2] if len(uts46row) == 3 else None

        if (
            status == "V"
            or (status == "D" and not transitional)
            or (status == "3" and not std3_rules and replacement is None)
        ):
            output.append(char)
        elif replacement is not None and (
            status == "M" or (status == "3" and not std3_rules) or (status == "D" and transitional)
        ):
            output.append(replacement)
        elif status != "I":
            # Raise directly instead of throwing IndexError just to reach a handler.
            raise _disallowed(pos, code_point)

    return unicodedata.normalize("NFC", "".join(output))
 | |
| 
 | |
| 
 | |
def encode(
    s: Union[str, bytes, bytearray],
    strict: bool = False,
    uts46: bool = False,
    std3_rules: bool = False,
    transitional: bool = False,
) -> bytes:
    """Encode a domain name to its ASCII (A-label) byte form.

    Args:
        s: the domain; bytes/bytearray input must be pure ASCII.
        strict: split labels only on "." rather than on every separator
            matched by ``_unicode_dots_re``.
        uts46: run uts46_remap() over the domain first.
        std3_rules: forwarded to uts46_remap().
        transitional: forwarded to uts46_remap().

    Returns:
        The alabel() of each label joined with b".", with a trailing dot
        preserved if the input had one.

    Raises:
        IDNAError: non-ASCII bytes input, empty domain or label, a result
            longer than valid_string_length() allows, or any per-label failure.
    """
    if isinstance(s, str):
        domain = s
    else:
        try:
            domain = str(s, "ascii")
        except UnicodeDecodeError:
            raise IDNAError("should pass a unicode string to the function rather than a byte string.")
    if uts46:
        domain = uts46_remap(domain, std3_rules, transitional)

    labels = domain.split(".") if strict else _unicode_dots_re.split(domain)
    if not labels or labels == [""]:
        raise IDNAError("Empty domain")
    trailing_dot = labels[-1] == ""
    if trailing_dot:
        labels.pop()

    encoded_labels = []
    for label in labels:
        encoded = alabel(label)
        if not encoded:
            raise IDNAError("Empty label")
        encoded_labels.append(encoded)
    if trailing_dot:
        encoded_labels.append(b"")

    result = b".".join(encoded_labels)
    if not valid_string_length(result, trailing_dot):
        raise IDNAError("Domain too long")
    return result
 | |
| 
 | |
| 
 | |
def decode(
    s: Union[str, bytes, bytearray],
    strict: bool = False,
    uts46: bool = False,
    std3_rules: bool = False,
) -> str:
    """Decode a domain name to its Unicode (U-label) form.

    Args:
        s: the domain; bytes/bytearray input must be pure ASCII.
        strict: split labels only on "." rather than on every separator
            matched by ``_unicode_dots_re``.
        uts46: run uts46_remap() (non-transitional) over the domain first.
        std3_rules: forwarded to uts46_remap().

    Returns:
        The ulabel() of each label joined with ".", with a trailing dot
        preserved if the input had one.

    Raises:
        IDNAError: non-ASCII bytes input, empty domain or label, or any
            per-label failure.
    """
    try:
        domain = s if isinstance(s, str) else str(s, "ascii")
    except UnicodeDecodeError:
        raise IDNAError("Invalid ASCII in A-label")
    if uts46:
        domain = uts46_remap(domain, std3_rules, False)

    labels = domain.split(".") if strict else _unicode_dots_re.split(domain)
    if not labels or labels == [""]:
        raise IDNAError("Empty domain")
    trailing_dot = not labels[-1]
    if trailing_dot:
        labels = labels[:-1]

    decoded_labels = []
    for label in labels:
        decoded = ulabel(label)
        if not decoded:
            raise IDNAError("Empty label")
        decoded_labels.append(decoded)
    if trailing_dot:
        decoded_labels.append("")
    return ".".join(decoded_labels)
 |