import * as THREE from "three";

/**
 * Build options for {@link BoxTree}. Unset fields fall back to the defaults
 * (maxDepth 8, maxElements 8000, quadtree). Note that the root node is always
 * split once when it is non-empty; these limits govern splitting of its
 * descendants.
 */
export interface BoxTreeOptions {
  /** Child nodes are only split further while their depth is below this value. */
  maxDepth: number;
  /** A child node is split further only when it holds more than this many elements. */
  maxElements: number;
  /** `true`: split on x, y and z into 8 children (octree); `false`: split on x and y into 4 (quadtree). */
  octree: boolean;
}

/**
 * A node of a {@link BoxTree}. Each node covers the contiguous range
 * `[offset, offset + count)` of the tree's root index.
 */
export interface BoxTreeNode {
  /** Key of this node in `BoxTree.nodes`. */
  id: number;
  /** Start of this node's range within the root index. */
  offset: number;
  /** Number of elements (point indices) in this node. */
  count: number;
  /** Distance from the root; an empty tree's root has depth 0 and each child is one deeper than its parent. */
  depth: number;
  /** Element indices of this node; for non-root nodes this is a subarray view into the root index. */
  index: Uint32Array;
  /** Axis-aligned bounds of the positions referenced by this node's elements. */
  boundingBox: THREE.Box3;
  /** Child nodes; empty for leaves. */
  children: BoxTreeNode[];
}

// Module-level scratch vector reused by BoxTree.splitNode when measuring a
// child's bounding-box size, so no Vector3 is allocated per node.
const _size = new THREE.Vector3();

/**
 * Spatial index over a point set given as packed (x, y, z) float triples.
 *
 * The tree keeps a single permutation of point indices (`root.index`). Every
 * node covers the contiguous range `[offset, offset + count)` of that array,
 * and splitting a node reorders its range in place so that each child's
 * elements are contiguous as well. Nodes are split at the centre of their
 * bounding box into 4 children (x/y, quadtree) or 8 children (x/y/z, octree).
 */
export class BoxTree {
  /** Root node; after a build its `index` holds every element, spatially sorted. */
  root: BoxTreeNode;
  /** Every node of the tree, keyed by `BoxTreeNode.id`. */
  nodes: Map<number, BoxTreeNode>;

  /** Values used for any option the caller leaves unset. */
  static defaultOptions: BoxTreeOptions = {
    maxDepth: 8,
    maxElements: 8000,
    octree: false,
  };

  /**
   * @param root - Root node of the tree.
   * @param nodes - Optional id → node map; built by walking `root` when omitted.
   */
  constructor(root: BoxTreeNode, nodes?: Map<number, BoxTreeNode>) {
    this.root = root;
    this.nodes = nodes ?? this.buildNodeMap(root);
  }

  /**
   * Collects `root` and all of its descendants into a map keyed by node id
   * (inserted in pre-order).
   */
  buildNodeMap(root: BoxTreeNode) {
    const nodes = new Map<number, BoxTreeNode>();
    const addNode = (node: BoxTreeNode) => {
      nodes.set(node.id, node);
      node.children.forEach((child) => addNode(child));
    };
    addNode(root);
    return nodes;
  }

  /**
   * Creates a child of `parent` covering `[offset, offset + count)` of the
   * root index and registers it in `nodes`.
   *
   * The id is the current size of `nodes`, so ids are sequential as long as
   * the map only holds nodes numbered `0..size-1`. The bounding box starts as
   * a copy of the parent's; callers are expected to refine it.
   *
   * @param parent - Parent node; the child is one level deeper.
   * @param offset - Global offset into the root index.
   * @param count - Number of elements in the child.
   */
  createNode(
    parent: BoxTreeNode,
    offset: number, // global offset
    count: number
  ): BoxTreeNode {
    const node: BoxTreeNode = {
      id: this.nodes.size,
      offset,
      count,
      depth: parent.depth + 1,
      // A view, not a copy: later in-place reordering of the root index is
      // visible through it.
      index: this.root.index.subarray(offset, offset + count),
      boundingBox: parent.boundingBox.clone(),
      children: [],
    };
    this.nodes.set(node.id, node);
    return node;
  }

  /** Creates a tree whose root (id 0, depth 0) has no elements. */
  static empty() {
    return new this({
      id: 0,
      offset: 0,
      count: 0,
      depth: 0,
      children: [],
      index: new Uint32Array(0),
      boundingBox: new THREE.Box3(),
    });
  }

  /**
   * Builds a new tree over `positions`.
   *
   * @param positions - Packed (x, y, z) triples.
   * @param options - Build options; unset fields use `BoxTree.defaultOptions`.
   */
  static fromPositions(
    positions: Float32Array,
    options?: Partial<BoxTreeOptions>
  ) {
    const tree = BoxTree.empty();
    tree.setFromPositions(positions, options);
    return tree;
  }

  /**
   * Rebuilds this tree over every point in `positions`, discarding any
   * previous structure.
   *
   * @param positions - Packed (x, y, z) triples; a trailing incomplete triple
   *   is ignored.
   * @param options - Build options; unset fields use `BoxTree.defaultOptions`.
   * @returns `this`, for chaining.
   */
  setFromPositions(positions: Float32Array, options?: Partial<BoxTreeOptions>) {
    if (positions.length === 0) {
      // Reset to an empty tree instead of silently keeping stale contents.
      return this.setFromIndex(new Uint32Array(0), positions, options);
    }

    console.time("Build tree");

    // Only complete triples become points; rounding up would index a vertex
    // whose coordinates lie past the end of `positions`.
    const count = Math.floor(positions.length / 3);
    const index = new Uint32Array(count);
    for (let i = 0; i < count; i++) {
      index[i] = i;
    }

    this.setFromIndex(index, positions, options);

    console.timeEnd("Build tree");
    console.info("box tree created with", this.nodes.size, "nodes.");

    return this;
  }

  /**
   * Rebuilds this tree over the points selected by `index`, discarding any
   * previous structure (node ids restart after the root).
   *
   * @param index - Point indices into `positions`; adopted as the root index
   *   (not copied) and reordered in place by the build.
   * @param positions - Packed (x, y, z) triples.
   * @param options - Build options; unset fields use `BoxTree.defaultOptions`.
   * @returns `this`, for chaining.
   */
  setFromIndex(
    index: Uint32Array,
    positions: Float32Array,
    options?: Partial<BoxTreeOptions>
  ) {
    const root = this.root;

    // Drop the previous structure. The map is cleared in place so any tree
    // sharing it (see `copy`) stays consistent with the shared root.
    root.offset = 0;
    root.count = index.length;
    root.index = index;
    root.children = [];
    this.nodes.clear();
    this.nodes.set(root.id, root);

    // Root bounding box
    const position = new THREE.Vector3();
    root.boundingBox.makeEmpty();
    for (let i = 0; i < root.count; i++) {
      position.fromArray(positions, 3 * index[i]);
      root.boundingBox.expandByPoint(position);
    }

    if (root.count > 0) {
      this.splitNode(root, index, positions, options);
    }

    return this;
  }

  /**
   * Splits `node` into 4 (quadtree) or 8 (octree) children around the centre
   * of its bounding box, then recursively splits every child that holds more
   * than `maxElements` elements, is shallower than `maxDepth`, and has a
   * non-zero extent along at least one split axis. `node` itself is always
   * split when non-empty. Existing children of `node` are replaced.
   *
   * @param node - Node to split.
   * @param rootIndex - The array node offsets refer to (normally
   *   `this.root.index`); the node's range is reordered in place.
   * @param positions - Packed (x, y, z) triples.
   * @param options - Build options; unset fields use `BoxTree.defaultOptions`.
   */
  splitNode(
    node: BoxTreeNode,
    rootIndex: Uint32Array,
    positions: Float32Array,
    options?: Partial<BoxTreeOptions>
  ) {
    const index = rootIndex;
    const count = node.count;

    // `??` (not `||`) so an explicit 0 is honoured rather than replaced.
    const octree = options?.octree ?? BoxTree.defaultOptions.octree;
    const maxDepth = options?.maxDepth ?? BoxTree.defaultOptions.maxDepth;
    const maxElements =
      options?.maxElements ?? BoxTree.defaultOptions.maxElements;
    const slices = octree ? 8 : 4;

    // Scratch buffers shared by all recursion levels; each level finishes
    // using them before it recurses into its children.
    const counts = new Uint32Array(8);
    const offsets = new Uint32Array(8);
    const tmpIndex = new Uint32Array(count);
    const tmpKeys = new Uint8Array(count);
    const position = new THREE.Vector3();
    const center = new THREE.Vector3();

    // A box with no extent along any split axis cannot be divided: every
    // element would land in the same child again, down to maxDepth.
    const isSeparable = (box: THREE.Box3) => {
      box.getSize(_size);
      return _size.x > 0 || _size.y > 0 || (octree && _size.z > 0);
    };

    const split = (target: BoxTreeNode) => {
      tmpIndex.set(index.subarray(target.offset, target.offset + target.count));
      counts.fill(0);

      // Classify each element by which side of the centre it lies on per axis.
      target.boundingBox.getCenter(center);
      for (let i = 0; i < target.count; i++) {
        const p = tmpIndex[i] * 3;
        let key = positions[p] > center.x ? 1 : 0;
        if (positions[p + 1] > center.y) key += 2;
        if (octree && positions[p + 2] > center.z) key += 4;
        tmpKeys[i] = key;
        counts[key]++;
      }

      // Start offset of each child's range (exclusive prefix sum of counts).
      let offset = target.offset;
      for (let k = 0; k < slices; k++) {
        offsets[k] = offset;
        offset += counts[k];
      }

      const children: BoxTreeNode[] = [];
      for (let k = 0; k < slices; k++) {
        const child = this.createNode(target, offsets[k], counts[k]);
        child.boundingBox.makeEmpty();
        children.push(child);
      }

      // Stable counting sort into the child ranges while growing child boxes.
      for (let i = 0; i < target.count; i++) {
        const key = tmpKeys[i];
        index[offsets[key]++] = tmpIndex[i];

        position.fromArray(positions, tmpIndex[i] * 3);
        children[key].boundingBox.expandByPoint(position);
      }

      target.children = children;
      for (const child of children) {
        if (
          child.count > maxElements &&
          child.depth < maxDepth &&
          isSeparable(child.boundingBox)
        ) {
          split(child);
        }
      }
    };

    if (count > 0) {
      split(node);
    }
  }

  /** The root index (all elements, spatially sorted after a build). */
  get index() {
    return this.root.index;
  }

  /** Number of elements in the tree. */
  get count() {
    return this.root.count;
  }

  /**
   * Concatenates the element ranges of `nodes`, in the given order, into a
   * new array.
   */
  getIndexForNodes(nodes: BoxTreeNode[]) {
    const total = nodes.reduce((sum, node) => sum + node.count, 0);
    const result = new Uint32Array(total);
    const source = this.index;
    let offset = 0;
    for (const node of nodes) {
      result.set(source.subarray(node.offset, node.offset + node.count), offset);
      offset += node.count;
    }
    return result;
  }

  /**
   * Returns the deepest nodes reachable through matching nodes: descends only
   * into nodes for which `match` returns true, and collects a matching node
   * when it is a leaf or its depth is at least `maxDepth`.
   */
  queryNodes(
    match: (node: BoxTreeNode) => boolean,
    maxDepth = Infinity
  ): BoxTreeNode[] {
    const nodes: BoxTreeNode[] = [];

    const visit = (node: BoxTreeNode) => {
      if (!match(node)) {
        return;
      }

      if (node.children.length === 0 || node.depth >= maxDepth) {
        nodes.push(node);
        return;
      }

      node.children.forEach((child) => visit(child));
    };

    visit(this.root);
    return nodes;
  }

  /**
   * Pre-order walk starting at `node`; returning `false` from `callback`
   * skips that node's descendants.
   */
  traverse(
    callback: (node: BoxTreeNode) => boolean | void,
    node: BoxTreeNode = this.root
  ) {
    if (callback(node) === false) return;
    for (let i = 0; i < node.children.length; i++) {
      this.traverse(callback, node.children[i]);
    }
  }

  /** Element indices of the nodes returned by {@link BoxTree.queryNodes}. */
  queryIndex(match: (node: BoxTreeNode) => boolean, maxDepth = Infinity) {
    return this.getIndexForNodes(this.queryNodes(match, maxDepth));
  }

  /**
   * Makes this tree refer to `source`'s root and node map. This shares
   * references rather than deep-copying, so later rebuilds affect both trees.
   */
  copy(source: BoxTree) {
    this.root = source.root;
    this.nodes = source.nodes;
    return this;
  }
}

// Default export of the same class that is also exported by name above.
export default BoxTree;
