import _ from "lodash";

import { EmbeddingAttribute } from "./attributes/EmbeddingAttribute";
import { Attribute } from "./attributes/Attribute";
import { PositionAttribute } from "./attributes/PositionAttribute";
import { DatasetObject } from "./types";
import { NumberAttribute } from "./attributes/NumberAttribute";
import { TextAttribute } from "./attributes/TextAttribute";
import { ImageAttribute } from "./attributes/ImageAttribute";
import { CategoryAttribute } from "./attributes/CategoryAttribute";
import { transpose } from "./utils";
import { TextIndex } from "./indices/TextIndex";

/**
 * Lookup table from a serialized attribute's `type` string to the attribute
 * class whose static `fromObject` deserializes it. `Dataset.fromObject` falls
 * back to the `"number"` entry when an attribute has no `type`.
 */
const AttributeTypes = {
  "attribute": Attribute,
  "number": NumberAttribute,
  "category": CategoryAttribute,
  "text": TextAttribute,
  "position": PositionAttribute,
  "embedding": EmbeddingAttribute,
  "image": ImageAttribute,
};

/**
 * An in-memory table of per-item attributes, with optional embeddings and
 * projections. Every attribute, the embeddings and the projections must
 * describe the same number of items (`count`); `ensureCount` enforces this.
 */
export class Dataset {
  /**
   * Number of items in the dataset. `0` doubles as "not yet known": the first
   * call to `ensureCount` while it is `0` adopts the count it is given.
   */
  count = 0;

  /** Per-item attributes, in insertion order. */
  attributes: Attribute[] = [];

  /** Optional embeddings, set via `setEmbeddings`. */
  embeddings?: EmbeddingAttribute;
  /** Optional projections (3 floats per item), set via `setProjections`. */
  projections?: PositionAttribute;

  /** Text index over this dataset; built by `buildIndices`. */
  textIndex: TextIndex;

  /**
   * Builds a dataset from a `DatasetObject` (the shape `toObject` produces).
   *
   * Attributes without a `type` are loaded as number attributes. A
   * `"category"` attribute with more categories than half its number of
   * values is loaded as a `TextAttribute` instead, with a console warning.
   * Projections come from `obj.projections` when present; otherwise they are
   * derived with `computeProjections`.
   *
   * @throws Error if an attribute has an unknown type, or if the attributes,
   *   embeddings and projections disagree on the item count.
   */
  static fromObject(obj: DatasetObject): Dataset {
    const attributes = obj.attributes.map((a) => {
      const AttributeType = AttributeTypes[a.type || "number"];
      if (!AttributeType) {
        throw new Error(`Unknown attribute type ${a.type}`);
      }

      // A "category" attribute whose number of distinct categories exceeds
      // half its number of values is loaded as text instead.
      if (a.type === "category" && a.categories.length > a.values.length / 2) {
        console.warn("Attribute " + a.name + " has many values.");
        return TextAttribute.fromCategories(a.name, a.values, a.categories);
      }

      return AttributeType.fromObject(a);
    });

    const dataset = new Dataset(attributes);

    if (obj.embeddings) {
      dataset.setEmbeddings(obj.embeddings.values, obj.embeddings.dimensions);
    }

    // Computed only after embeddings are set, so they can be projected.
    const projections = obj.projections || dataset.computeProjections();
    if (projections) {
      dataset.setProjections(projections);
    }

    return dataset;
  }

  /**
   * @param attributes initial attributes; all must share the same count.
   * @throws Error if the attributes disagree on the item count.
   */
  constructor(attributes: Attribute[] = []) {
    attributes.forEach((attribute) => this.addAttribute(attribute));
    this.textIndex = new TextIndex(this);
  }

  /** Builds the dataset's search indices (currently the text index). */
  async buildIndices(): Promise<void> {
    await this.textIndex?.build();
  }

  /**
   * Checks that `count` agrees with the dataset's item count, adopting it if
   * the dataset's count is still unset (`0`).
   *
   * @throws Error if `count` is not a non-negative integer (e.g. a flat value
   *   array whose length is not a multiple of its per-item stride), or if it
   *   differs from an already established count.
   */
  ensureCount(count: number) {
    // Reject fractional/NaN/Infinity counts: they arise from a length that is
    // not a multiple of the stride (or a zero stride) and would otherwise be
    // stored and poison every later comparison.
    if (!Number.isInteger(count) || count < 0) {
      throw new Error(`Invalid count ${count}: expected a non-negative integer`);
    }

    if (this.count === 0) {
      this.count = count;
      return;
    }

    if (this.count !== count) {
      throw new Error(
        `Invalid count ${count} for a dataset with count ${this.count}`
      );
    }
  }

  // Attributes

  /**
   * Appends an attribute after checking its count against the dataset's.
   *
   * @throws Error if `attribute.count` disagrees with the dataset's count.
   */
  addAttribute(attribute: Attribute) {
    // Validate the attribute's own length; checking `this.count` against
    // itself (as before) could never fail.
    this.ensureCount(attribute.count);
    this.attributes.push(attribute);
  }

  /** Returns the first attribute named `name`, or `undefined`. */
  getAttributeByName(name: string): Attribute | undefined {
    return this.attributes.find((a) => a.name === name);
  }

  /** Returns the attribute names in insertion order. */
  getAttributeNames(): string[] {
    return this.attributes.map((attribute) => attribute.name);
  }

  // Embeddings

  /**
   * Sets the embeddings from a flat array holding `dimensions` floats per item.
   *
   * @throws Error if `values.length / dimensions` is not a valid count or
   *   disagrees with the dataset's count.
   */
  setEmbeddings(values: Float32Array, dimensions: number) {
    this.ensureCount(values.length / dimensions);
    this.embeddings = new EmbeddingAttribute(values, dimensions);
  }

  // Projections

  /**
   * Sets the projections from a flat array holding 3 floats per item.
   *
   * @throws Error if `values.length / 3` is not a valid count or disagrees
   *   with the dataset's count.
   */
  setProjections(values: Float32Array) {
    this.ensureCount(values.length / 3);
    this.projections = new PositionAttribute(values);
  }

  /**
   * Derives projections when none were supplied. If embeddings are set, they
   * are projected. Otherwise the number attributes named `x`, `y` and
   * (optionally) `z` are combined with `transpose`. A zero-filled `z` is used
   * when there is no numeric `z`.
   *
   * @returns the derived values, or `undefined` when there are neither
   *   embeddings nor numeric `x` and `y` attributes.
   */
  computeProjections() {
    if (this.embeddings) {
      return this.embeddings.project();
    }

    const x = this.getAttributeByName("x");
    const y = this.getAttributeByName("y");
    const z = this.getAttributeByName("z");

    if (!(x instanceof NumberAttribute) || !(y instanceof NumberAttribute)) {
      return undefined;
    }

    const zValues =
      z instanceof NumberAttribute ? z.values : new Float32Array(x.count);
    return transpose(x.values, y.values, zValues);
  }

  // Values

  /**
   * Returns each attribute's display value for item `index`, keyed by
   * attribute name (a later attribute wins if two share a name).
   */
  getDisplayValues(index: number): Map<string, string> {
    return new Map<string, string>(
      this.attributes.map(
        (attribute): [string, string] => [
          attribute.name,
          attribute.getDisplayValue(index),
        ]
      )
    );
  }

  /** Serializes the dataset into a `DatasetObject`. */
  toObject(): DatasetObject {
    return {
      attributes: this.attributes.map((attribute) => attribute.toObject()),
      embeddings: this.embeddings?.toObject(),
      projections: this.projections?.values,
    };
  }
}
