Module moody.m.kyc.merkletree

Expand source code
from typing import List, Any, Callable, Dict, Optional, Union
from hexbytes import HexBytes

import hashlib


class MerkleTree:
    """
    Merkle Tree Lx.

    """

    def __init__(
            self, leaves: List[Any], hash_function: Callable, sort: Optional[bool] = False
    ) -> None:
        self.hash_function = self.bufferify_function(hash_function)
        self.leaves = leaves
        self.sort_leaves = sort
        self.sort_pairs = sort
        self._process_leaves()

    @staticmethod
    def to_hex(value: Any) -> str:
        return value.hex()

    @staticmethod
    def bufferify(value: str) -> str:
        if type(value) == bytes:
            return value
        else:
            return value.encode()

    @staticmethod
    def bufferify_function(func: Callable) -> Callable:
        def f(val):
            return MerkleTree.bufferify(func(val))

        return f

    def _process_leaves(self) -> None:
        self.leaves = [self.bufferify(leaf) for leaf in self.leaves]
        if self.sort_leaves:
            self.leaves.sort()
        self.layers = [self.leaves]
        self._create_hashes(self.leaves)

    def _create_hashes(self, nodes: List[Any]) -> None:
        while len(nodes) > 1:
            n = len(nodes)
            layer_index = len(self.layers)
            self.layers.append([])
            for i in range(0, n, 2):
                if n == i + 1 and n % 2 == 1:
                    self.layers[layer_index].append(nodes[i])
                    continue
                left = self.bufferify(nodes[i])
                right = left if i + 1 == n else self.bufferify(nodes[i + 1])
                combined = (
                    left + right
                    if not (self.sort_pairs and right > left)
                    else right + left
                )
                hashed_data = self.hash_function(combined)
                self.layers[layer_index].append(hashed_data)
            nodes = self.layers[layer_index]

    def get_hex_layers(self) -> List[str]:
        return [[self.to_hex(leaf) for leaf in layer] for layer in self.layers]

    def get_root(self) -> any:
        try:
            return self.layers[-1][0]
        except IndexError:
            return []

    def get_hex_root(self) -> any:
        if self.get_root():
            return self.to_hex(self.get_root())
        else:
            return []

    def get_proof(self, leaf: Union[str, bytes], index: Optional[int] = None) -> List[Dict[str, Any]]:
        proof = []
        leaf = self.bufferify(leaf)
        if not index:
            try:
                index = self.leaves.index(leaf)
            except ValueError:
                return []

        for layer in self.layers:
            is_right_node = index % 2
            pair_index = index - 1 if is_right_node else index + 1
            if pair_index < len(layer):
                proof.append(
                    {
                        "position": "left" if is_right_node else "right",
                        "data": layer[pair_index],
                    }
                )
            index = index // 2
        return proof

    def get_hex_proof(self, leaf: Union[str, bytes], index: Optional[int] = 0) -> List[str]:
        return [self.to_hex(item["data"]) for item in self.get_proof(leaf, index)]

    def verify(self, proof: List[Dict[str, str]], target_node: str, root: str) -> bool:
        hash = self.bufferify(target_node)
        root = self.bufferify(root)
        for node in proof:
            data = self.bufferify(node["data"])
            is_left_node = node["position"] == "left"
            if self.sort_pairs:
                combined = data + hash if data <= hash else hash + data
            else:
                combined = data + hash if is_left_node else hash + data
            hash = self.hash_function(combined)
        return hash == root

    def get_depth(self) -> int:
        return len(self.layers) - 1

    def reset_tree(self) -> None:
        self.leaves = []
        self.layers = []


def sha256(x) -> bytes:
    return hashlib.sha256(x).digest()


"""
leaves = [sha256(leaf.encode()) for leaf in "abc"]
tree = MerkleTree(leaves, sha256)
root = tree.get_root()
hex_root = tree.get_hex_root()
leaf = sha256("a".encode())
bad_leaf = sha256("x".encode())
proof = tree.get_proof(leaf)
ok = tree.verify(proof, leaf, root)  # returns True
ok2 = tree.verify(proof, bad_leaf, root)  # returns False
print(hex_root, tree.get_hex_proof(leaf), ok, ok2)

"""


class Kyc:
    def __init__(self):
        self.data = list()
        self._tree = None
        self.root_hex = None
        self.root_layers = None

    def updateList(self, address_list: List[str]) -> "Kyc":
        self.data = address_list
        self._tree = MerkleTree([sha256(leaf.encode()) for leaf in address_list], sha256)
        self.root_hex = self._tree.get_hex_root()
        self.root_layers = self._tree.get_hex_layers()
        return self

    def getKycDataFromAddress(self, address: str) -> list:
        if address not in self.data:
            print(f"Address {address} is not inside the collection")
            return list()
        if self._tree is None:
            print("tree object is not initialized, please updateList the collection first.")
            return list()

        leaf = sha256(address.encode())
        kyc_proof = self._tree.get_hex_proof(leaf)
        proof = self._tree.get_proof(leaf)
        proof_ok = self._tree.verify(proof, leaf, self._tree.get_root())
        print(f"KYC proof ok? {proof_ok}")

        return kyc_proof

Functions

def sha256(x) ‑> bytes
Expand source code
def sha256(x) -> bytes:
    return hashlib.sha256(x).digest()

Classes

class Kyc
Expand source code
class Kyc:
    def __init__(self):
        self.data = list()
        self._tree = None
        self.root_hex = None
        self.root_layers = None

    def updateList(self, address_list: List[str]) -> "Kyc":
        self.data = address_list
        self._tree = MerkleTree([sha256(leaf.encode()) for leaf in address_list], sha256)
        self.root_hex = self._tree.get_hex_root()
        self.root_layers = self._tree.get_hex_layers()
        return self

    def getKycDataFromAddress(self, address: str) -> list:
        if address not in self.data:
            print(f"Address {address} is not inside the collection")
            return list()
        if self._tree is None:
            print("tree object is not initialized, please updateList the collection first.")
            return list()

        leaf = sha256(address.encode())
        kyc_proof = self._tree.get_hex_proof(leaf)
        proof = self._tree.get_proof(leaf)
        proof_ok = self._tree.verify(proof, leaf, self._tree.get_root())
        print(f"KYC proof ok? {proof_ok}")

        return kyc_proof

Methods

def getKycDataFromAddress(self, address: str) ‑> list
Expand source code
def getKycDataFromAddress(self, address: str) -> list:
    if address not in self.data:
        print(f"Address {address} is not inside the collection")
        return list()
    if self._tree is None:
        print("tree object is not initialized, please updateList the collection first.")
        return list()

    leaf = sha256(address.encode())
    kyc_proof = self._tree.get_hex_proof(leaf)
    proof = self._tree.get_proof(leaf)
    proof_ok = self._tree.verify(proof, leaf, self._tree.get_root())
    print(f"KYC proof ok? {proof_ok}")

    return kyc_proof
def updateList(self, address_list: List[str]) ‑> Kyc
Expand source code
def updateList(self, address_list: List[str]) -> "Kyc":
    self.data = address_list
    self._tree = MerkleTree([sha256(leaf.encode()) for leaf in address_list], sha256)
    self.root_hex = self._tree.get_hex_root()
    self.root_layers = self._tree.get_hex_layers()
    return self
class MerkleTree (leaves: List[Any], hash_function: Callable, sort: Union[bool, NoneType] = False)

Merkle Tree Lx.

Expand source code
class MerkleTree:
    """
    Merkle Tree Lx.

    """

    def __init__(
            self, leaves: List[Any], hash_function: Callable, sort: Optional[bool] = False
    ) -> None:
        self.hash_function = self.bufferify_function(hash_function)
        self.leaves = leaves
        self.sort_leaves = sort
        self.sort_pairs = sort
        self._process_leaves()

    @staticmethod
    def to_hex(value: Any) -> str:
        return value.hex()

    @staticmethod
    def bufferify(value: str) -> str:
        if type(value) == bytes:
            return value
        else:
            return value.encode()

    @staticmethod
    def bufferify_function(func: Callable) -> Callable:
        def f(val):
            return MerkleTree.bufferify(func(val))

        return f

    def _process_leaves(self) -> None:
        self.leaves = [self.bufferify(leaf) for leaf in self.leaves]
        if self.sort_leaves:
            self.leaves.sort()
        self.layers = [self.leaves]
        self._create_hashes(self.leaves)

    def _create_hashes(self, nodes: List[Any]) -> None:
        while len(nodes) > 1:
            n = len(nodes)
            layer_index = len(self.layers)
            self.layers.append([])
            for i in range(0, n, 2):
                if n == i + 1 and n % 2 == 1:
                    self.layers[layer_index].append(nodes[i])
                    continue
                left = self.bufferify(nodes[i])
                right = left if i + 1 == n else self.bufferify(nodes[i + 1])
                combined = (
                    left + right
                    if not (self.sort_pairs and right > left)
                    else right + left
                )
                hashed_data = self.hash_function(combined)
                self.layers[layer_index].append(hashed_data)
            nodes = self.layers[layer_index]

    def get_hex_layers(self) -> List[str]:
        return [[self.to_hex(leaf) for leaf in layer] for layer in self.layers]

    def get_root(self) -> any:
        try:
            return self.layers[-1][0]
        except IndexError:
            return []

    def get_hex_root(self) -> any:
        if self.get_root():
            return self.to_hex(self.get_root())
        else:
            return []

    def get_proof(self, leaf: Union[str, bytes], index: Optional[int] = None) -> List[Dict[str, Any]]:
        proof = []
        leaf = self.bufferify(leaf)
        if not index:
            try:
                index = self.leaves.index(leaf)
            except ValueError:
                return []

        for layer in self.layers:
            is_right_node = index % 2
            pair_index = index - 1 if is_right_node else index + 1
            if pair_index < len(layer):
                proof.append(
                    {
                        "position": "left" if is_right_node else "right",
                        "data": layer[pair_index],
                    }
                )
            index = index // 2
        return proof

    def get_hex_proof(self, leaf: Union[str, bytes], index: Optional[int] = 0) -> List[str]:
        return [self.to_hex(item["data"]) for item in self.get_proof(leaf, index)]

    def verify(self, proof: List[Dict[str, str]], target_node: str, root: str) -> bool:
        hash = self.bufferify(target_node)
        root = self.bufferify(root)
        for node in proof:
            data = self.bufferify(node["data"])
            is_left_node = node["position"] == "left"
            if self.sort_pairs:
                combined = data + hash if data <= hash else hash + data
            else:
                combined = data + hash if is_left_node else hash + data
            hash = self.hash_function(combined)
        return hash == root

    def get_depth(self) -> int:
        return len(self.layers) - 1

    def reset_tree(self) -> None:
        self.leaves = []
        self.layers = []

Static methods

def bufferify(value: str) ‑> str
Expand source code
@staticmethod
def bufferify(value: str) -> str:
    if type(value) == bytes:
        return value
    else:
        return value.encode()
def bufferify_function(func: Callable) ‑> Callable
Expand source code
@staticmethod
def bufferify_function(func: Callable) -> Callable:
    def f(val):
        return MerkleTree.bufferify(func(val))

    return f
def to_hex(value: Any) ‑> str
Expand source code
@staticmethod
def to_hex(value: Any) -> str:
    return value.hex()

Methods

def get_depth(self) ‑> int
Expand source code
def get_depth(self) -> int:
    return len(self.layers) - 1
def get_hex_layers(self) ‑> List[str]
Expand source code
def get_hex_layers(self) -> List[str]:
    return [[self.to_hex(leaf) for leaf in layer] for layer in self.layers]
def get_hex_proof(self, leaf: Union[str, bytes], index: Union[int, NoneType] = 0) ‑> List[str]
Expand source code
def get_hex_proof(self, leaf: Union[str, bytes], index: Optional[int] = 0) -> List[str]:
    return [self.to_hex(item["data"]) for item in self.get_proof(leaf, index)]
def get_hex_root(self) ‑> 
Expand source code
def get_hex_root(self) -> any:
    if self.get_root():
        return self.to_hex(self.get_root())
    else:
        return []
def get_proof(self, leaf: Union[str, bytes], index: Union[int, NoneType] = None) ‑> List[Dict[str, Any]]
Expand source code
def get_proof(self, leaf: Union[str, bytes], index: Optional[int] = None) -> List[Dict[str, Any]]:
    proof = []
    leaf = self.bufferify(leaf)
    if not index:
        try:
            index = self.leaves.index(leaf)
        except ValueError:
            return []

    for layer in self.layers:
        is_right_node = index % 2
        pair_index = index - 1 if is_right_node else index + 1
        if pair_index < len(layer):
            proof.append(
                {
                    "position": "left" if is_right_node else "right",
                    "data": layer[pair_index],
                }
            )
        index = index // 2
    return proof
def get_root(self) ‑> 
Expand source code
def get_root(self) -> any:
    try:
        return self.layers[-1][0]
    except IndexError:
        return []
def reset_tree(self) ‑> NoneType
Expand source code
def reset_tree(self) -> None:
    self.leaves = []
    self.layers = []
def verify(self, proof: List[Dict[str, str]], target_node: str, root: str) ‑> bool
Expand source code
def verify(self, proof: List[Dict[str, str]], target_node: str, root: str) -> bool:
    hash = self.bufferify(target_node)
    root = self.bufferify(root)
    for node in proof:
        data = self.bufferify(node["data"])
        is_left_node = node["position"] == "left"
        if self.sort_pairs:
            combined = data + hash if data <= hash else hash + data
        else:
            combined = data + hash if is_left_node else hash + data
        hash = self.hash_function(combined)
    return hash == root