import React, { useEffect, useImperativeHandle, useRef } from "react";
import { invalidate, useThree } from "@react-three/fiber";
import { Box3, Group, Object3D, PerspectiveCamera, Quaternion, Raycaster, Vector2, Vector3 } from "three";
import { useSpring } from "react-spring";

/** Imperative handle exposed by the ViewRControls component through its `ref`. */
export type ViewRControlsRef = {
  /**
   * Animates the camera to a point behind the model's bounding-box center,
   * along the camera's current viewing direction, far enough away that the
   * box diagonal fits the viewport.
   *
   * @param fitFac - Multiplier on the computed fit distance (the component
   *   defaults it to 1); values above 1 leave extra margin around the model.
   */
  fitCameraToModel: (fitFac?: number) => void;
};

/** Props accepted by the ViewRControls component. */
export type ViewRControlsProps = {
  /**
   * Root of the model. Its children are raycast to pick the orbit pivot and
   * the zoom slow-down distance; its bounding box drives `fitCameraToModel`.
   */
  modelGroupRef: React.RefObject<Object3D>;

  /**
   * Orbit angle (radians) per unit of normalized pointer movement, where the
   * canvas spans 2 units edge to edge. Defaults to 2.5.
   */
  rotateSensitivity?: number;

  /** Camera pan distance per unit of normalized pointer movement. Defaults to 25. */
  panSensitivity?: number;

  /**
   * Factor applied to the wheel `deltaY` when dollying toward the far-plane
   * point under the pointer. Defaults to 0.00001.
   */
  zoomSensitivity?: number;

  /**
   * When the model surface under the pointer is at most this far away, wheel
   * zoom speed is scaled down proportionally (never below 5%). Defaults to 5.
   */
  zoomSlowDownDistance?: number;
};

/**
 * Gesture state of the camera controls, discriminated by `type`:
 * - `"stationary"`: no gesture in progress; holds the current orbit angles (radians).
 * - `"rotate"`: orbit drag in progress; the camera is temporarily parented to `rig`.
 * - `"panZoom"`: pan / pinch gesture in progress; `start` snapshots the gesture origin.
 */
type CamMoveConfig =
  | {
      type: "rotate";
      /** Pointer position (normalized device coords) and orbit angles when the drag began. */
      start: {
        relPointerPos: Vector2;
        camAngleX: number;
        camAngleY: number;
      };
      /** Y-rotation pivot group placed at the orbit center; its child rotates about X and holds the camera. */
      rig: Group;
    }
  | {
      type: "stationary";
      camAngleX: number;
      camAngleY: number;
    }
  | {
      type: "panZoom";
      start: {
        relPointerPos: Vector2;
        /** Distance between the two touch points; absent for mouse pans. */
        pointerDist?: number;
        camPos: Vector3;
      };
      camAngleX: number;
      camAngleY: number;
    };

/**
 * Orbit / pan / zoom camera controls for a react-three-fiber scene.
 *
 * On mount (and whenever a dependency changes) the camera is reset to a fixed
 * pose, tilted -45° about X at (0, 9999, 9999), and re-parented under this
 * component's `<group>`. DOM listeners on the canvas then drive it directly:
 * - left or middle mouse drag / one-finger touch: orbit around the model point
 *   under the pointer (or around the camera itself when nothing is hit);
 * - right mouse drag: pan in the camera's view plane;
 * - two-finger touch: pan plus pinch-zoom;
 * - mouse wheel: dolly toward the pointer, slowing down near the model.
 *
 * A requestAnimationFrame loop calls r3f's `invalidate()` whenever the camera's
 * world position or direction changed, so on-demand rendering follows camera motion.
 *
 * The ref exposes `fitCameraToModel(fitFac = 1)`, which animates the camera so
 * the model's bounding box fits the viewport.
 */
const ViewRControls = React.forwardRef<ViewRControlsRef, ViewRControlsProps>((
  {
    modelGroupRef,
    rotateSensitivity = 2.5,
    panSensitivity = 25,
    zoomSensitivity = 0.00001,
    zoomSlowDownDistance = 5,
  },
  ref,
) => {
  const { gl, camera, size } = useThree();
  const camMoveConfig = useRef<CamMoveConfig | undefined>(undefined);
  // Declared ahead of the effect that reads it so the data flow reads top-down.
  const controlsParentRef = useRef<Object3D>(null);

  // Animates the camera position for fitCameraToModel. Declared before the
  // effect so user gestures can stop an in-flight animation.
  const [, springApi] = useSpring<{ x: number; y: number; z: number }>(() => ({
    x: 0,
    y: 0,
    z: 0,
    onChange: ({ value: { x, y, z } }) => {
      camera.position.set(x, y, z);
    },
  }));

  useEffect(() => {
    // Snapshot the parent group: by the time the cleanup below runs on unmount,
    // React may already have cleared controlsParentRef.current.
    const controlsParent = controlsParentRef.current;
    const raycaster = new Raycaster();

    const camStartTilt = -45 * Math.PI / 180;
    const camStartPos = new Vector3(0, 9999, 9999);

    camMoveConfig.current = {
      type: "stationary",
      camAngleX: camStartTilt,
      camAngleY: 0,
    };

    camera.rotation.set(camStartTilt, 0, 0);
    camera.position.copy(camStartPos);
    controlsParent?.add(camera);

    /** Maps client coordinates (or the midpoint of two points) to normalized device coordinates in [-1, 1]. */
    const getRelativeMousePos = (x1: number, y1: number, x2?: number, y2?: number) => {
      const cr = gl.domElement.getBoundingClientRect();
      if (x2 === undefined || y2 === undefined) {
        return new Vector2((x1 - cr.left) / cr.width * 2 - 1, 1 - (y1 - cr.top) / cr.height * 2);
      }
      const avX = ((x1 - cr.left) + (x2 - cr.left)) / 2;
      const avY = ((y1 - cr.top) + (y2 - cr.top)) / 2;
      return new Vector2(avX / cr.width * 2 - 1, 1 - avY / cr.height * 2);
    };

    /** Euclidean distance between two touch points, in client pixels. */
    const getTouchDist = (x1: number, y1: number, x2: number, y2: number) => Math.hypot(x2 - x1, y2 - y1);

    /** stationary -> rotate: builds a Y-then-X rig at the pivot and attaches the camera to it. */
    const startControlsRotation = (relPointerPos: Vector2) => {
      const cfg = camMoveConfig.current;
      if (cfg?.type !== "stationary") return;
      // A gesture takes over from any running fit animation; otherwise the spring
      // would keep overwriting camera.position, which is rig-local during a rotation.
      springApi.stop();
      raycaster.setFromCamera(relPointerPos, camera);
      const hit = raycaster.intersectObjects(modelGroupRef.current?.children ?? [], true)[0];
      const rigY = new Group();
      // Orbit around the model point under the pointer; with no hit, around the camera itself.
      rigY.position.copy(hit ? hit.point : camera.position);
      rigY.rotation.set(0, cfg.camAngleY, 0);
      controlsParent?.add(rigY);
      const rigX = new Group();
      rigX.rotation.set(cfg.camAngleX, 0, 0);
      rigY.add(rigX);
      // attach() preserves the camera's world transform while re-parenting it.
      rigX.attach(camera);
      camMoveConfig.current = {
        type: "rotate",
        start: {
          relPointerPos,
          camAngleX: cfg.camAngleX,
          camAngleY: cfg.camAngleY,
        },
        rig: rigY,
      };
    };
    /** Rotates the rig by the pointer delta since the drag started. */
    const updateControlsRotation = (relPointerPos: Vector2) => {
      const cfg = camMoveConfig.current;
      if (cfg?.type !== "rotate") return;
      const rigX = cfg.rig.children[0];
      if (!rigX) return;
      const mouseDelta = cfg.start.relPointerPos.clone().sub(relPointerPos);
      cfg.rig.rotation.y = cfg.start.camAngleY + mouseDelta.x * rotateSensitivity;
      rigX.rotation.x = cfg.start.camAngleX - mouseDelta.y * rotateSensitivity;
    };
    /** rotate -> stationary: re-parents the camera (keeping its world pose) and discards the rig. */
    const endControlsRotation = () => {
      const cfg = camMoveConfig.current;
      if (cfg?.type !== "rotate") return;
      controlsParent?.attach(camera);
      controlsParent?.remove(cfg.rig);
      camMoveConfig.current = {
        type: "stationary",
        camAngleX: cfg.rig.children[0]?.rotation.x ?? cfg.start.camAngleX,
        camAngleY: cfg.rig.rotation.y,
      };
    };

    /** stationary -> panZoom: snapshots the gesture origin and camera position. */
    const startControlsPanZoom = (relPointerPos: Vector2, pointerDist?: number) => {
      const cfg = camMoveConfig.current;
      if (cfg?.type !== "stationary") return;
      springApi.stop();
      camMoveConfig.current = {
        type: "panZoom",
        start: {
          relPointerPos,
          camPos: camera.position.clone(),
          pointerDist,
        },
        camAngleX: cfg.camAngleX,
        camAngleY: cfg.camAngleY,
      };
    };
    /** Pans in the camera's view plane and, for two-finger touch, dollies by the pinch delta. */
    const updateControlsPanZoom = (relPointerPos: Vector2, pointerDist?: number) => {
      const cfg = camMoveConfig.current;
      if (cfg?.type !== "panZoom") return;
      const mouseDelta = cfg.start.relPointerPos.clone().sub(relPointerPos);
      const camQuat = camera.getWorldQuaternion(new Quaternion());
      const camRight = new Vector3(1, 0, 0).applyQuaternion(camQuat);
      const camUp = new Vector3(0, 1, 0).applyQuaternion(camQuat);
      const camBack = new Vector3(0, 0, 1).applyQuaternion(camQuat);
      // Fingers moving apart (pointerDist grows) gives a negative delta, i.e. the camera moves forward.
      const pinchDelta = (cfg.start.pointerDist && pointerDist) ? cfg.start.pointerDist - pointerDist : 0;
      camera.position
        .copy(cfg.start.camPos)
        .addScaledVector(camRight, mouseDelta.x * panSensitivity)
        .addScaledVector(camUp, mouseDelta.y * panSensitivity)
        .addScaledVector(camBack, 0.1 * pinchDelta);
    };
    /** panZoom -> stationary. */
    const endControlsPanZoom = () => {
      const cfg = camMoveConfig.current;
      if (cfg?.type !== "panZoom") return;
      camMoveConfig.current = {
        type: "stationary",
        camAngleX: cfg.camAngleX,
        camAngleY: cfg.camAngleY,
      };
    };

    /** Dollies the camera toward the point under the pointer by wheel amount `z`. */
    const controlsZoom = (relPointerPos: Vector2, z: number) => {
      if (camMoveConfig.current?.type !== "stationary") return;
      springApi.stop();
      // Vector from the camera to the far-plane point under the pointer.
      const zoomVec = new Vector3(relPointerPos.x, relPointerPos.y, 1).unproject(camera).sub(camera.position);
      raycaster.setFromCamera(relPointerPos, camera);
      const hit = raycaster.intersectObjects(modelGroupRef.current?.children ?? [], true)[0];
      // Slow down linearly within zoomSlowDownDistance of the surface, keeping at least 5% speed.
      const slowDown = (hit !== undefined && hit.distance <= zoomSlowDownDistance)
        ? Math.max(hit.distance / zoomSlowDownDistance, 0.05)
        : 1;
      const camSpeed = zoomSensitivity * slowDown;
      camera.position.addScaledVector(zoomVec, -camSpeed * z);
    };

    const pointerDownListener = (e: PointerEvent) => {
      gl.domElement.setPointerCapture(e.pointerId);
    };
    const mouseDownListener = (e: MouseEvent) => {
      const currRelMousePos = getRelativeMousePos(e.clientX, e.clientY);
      if (e.buttons === 1 || e.buttons === 4) {
        startControlsRotation(currRelMousePos);
      }
      if (e.buttons === 2) {
        startControlsPanZoom(currRelMousePos);
      }
    };
    const touchStartListener = (e: TouchEvent) => {
      if (e.touches.length === 1) {
        const currRelTouchPos = getRelativeMousePos(e.touches[0].clientX, e.touches[0].clientY);
        startControlsRotation(currRelTouchPos);
      }
      if (e.touches.length === 2) {
        endControlsRotation();
        const currRelTouchPos = getRelativeMousePos(e.touches[0].clientX, e.touches[0].clientY, e.touches[1].clientX, e.touches[1].clientY);
        startControlsPanZoom(currRelTouchPos, getTouchDist(e.touches[0].clientX, e.touches[0].clientY, e.touches[1].clientX, e.touches[1].clientY));
      }
    };
    const mouseMoveListener = (e: MouseEvent) => {
      const currRelMousePos = getRelativeMousePos(e.clientX, e.clientY);
      updateControlsRotation(currRelMousePos);
      updateControlsPanZoom(currRelMousePos);
    };
    const touchMoveListener = (e: TouchEvent) => {
      if (e.touches.length === 1) {
        const currRelTouchPos = getRelativeMousePos(e.touches[0].clientX, e.touches[0].clientY);
        updateControlsRotation(currRelTouchPos);
      }
      if (e.touches.length === 2) {
        const currRelTouchPos = getRelativeMousePos(e.touches[0].clientX, e.touches[0].clientY, e.touches[1].clientX, e.touches[1].clientY);
        updateControlsPanZoom(currRelTouchPos, getTouchDist(e.touches[0].clientX, e.touches[0].clientY, e.touches[1].clientX, e.touches[1].clientY));
      }
    };
    const mouseUpListener = (e: MouseEvent) => {
      if (!(e.buttons === 1 || e.buttons === 4)) {
        endControlsRotation();
      }
      if (!(e.buttons === 2)) {
        endControlsPanZoom();
      }
    };
    const touchEndListener = (e: TouchEvent) => {
      if (e.touches.length === 1) {
        // Dropping from two fingers to one switches from pan/zoom back to orbit.
        endControlsPanZoom();
        const currRelTouchPos = getRelativeMousePos(e.touches[0].clientX, e.touches[0].clientY);
        startControlsRotation(currRelTouchPos);
      }
      if (e.touches.length === 0) {
        endControlsPanZoom();
        endControlsRotation();
      }
    };
    const contextMenuListener = (e: Event) => {
      // Right-drag pans, so suppress the browser context menu.
      e.preventDefault();
    };
    const wheelListener = (e: WheelEvent) => {
      const currRelMousePos = getRelativeMousePos(e.clientX, e.clientY);
      controlsZoom(currRelMousePos, e.deltaY);
    };

    gl.domElement.addEventListener("contextmenu", contextMenuListener);

    gl.domElement.addEventListener("pointerdown", pointerDownListener, { passive: true });
    gl.domElement.addEventListener("mousedown", mouseDownListener, { passive: true });
    gl.domElement.addEventListener("touchstart", touchStartListener, { passive: true });

    gl.domElement.addEventListener("mousemove", mouseMoveListener, { passive: true });
    gl.domElement.addEventListener("touchmove", touchMoveListener, { passive: true });

    gl.domElement.addEventListener("mouseup", mouseUpListener, { passive: true });
    gl.domElement.addEventListener("touchend", touchEndListener, { passive: true });

    gl.domElement.addEventListener("wheel", wheelListener, { passive: true });

    // Request a render only when the camera's world pose actually changed.
    const prevCameraPos = camera.getWorldPosition(new Vector3());
    const prevCameraDir = camera.getWorldDirection(new Vector3());
    const currCameraPos = new Vector3();
    const currCameraDir = new Vector3();
    let frameId = 0;
    const onFrame = () => {
      camera.getWorldPosition(currCameraPos);
      camera.getWorldDirection(currCameraDir);
      if (
        currCameraPos.distanceTo(prevCameraPos) > 0.00001 ||
        currCameraDir.distanceTo(prevCameraDir) > 0.00001
      ) {
        prevCameraPos.copy(currCameraPos);
        prevCameraDir.copy(currCameraDir);
        invalidate();
      }
      frameId = requestAnimationFrame(onFrame);
    };
    onFrame();

    return () => {
      cancelAnimationFrame(frameId);
      springApi.stop();

      gl.domElement.removeEventListener("contextmenu", contextMenuListener);
      gl.domElement.removeEventListener("pointerdown", pointerDownListener);
      gl.domElement.removeEventListener("mousedown", mouseDownListener);
      gl.domElement.removeEventListener("touchstart", touchStartListener);
      gl.domElement.removeEventListener("mousemove", mouseMoveListener);
      gl.domElement.removeEventListener("touchmove", touchMoveListener);
      gl.domElement.removeEventListener("mouseup", mouseUpListener);
      gl.domElement.removeEventListener("touchend", touchEndListener);
      gl.domElement.removeEventListener("wheel", wheelListener);

      // Tear down any gesture in progress (this also removes a rotation rig),
      // then detach the camera from the group it was added to above.
      endControlsRotation();
      endControlsPanZoom();
      controlsParent?.remove(camera);
    };
  }, [gl, camera, modelGroupRef, rotateSensitivity, panSensitivity, zoomSensitivity, zoomSlowDownDistance, springApi]);

  useImperativeHandle(ref, () => ({
    fitCameraToModel: (fitFac = 1) => {
      const model = modelGroupRef.current;
      // The fit distance is derived from a vertical field of view, which only perspective cameras have.
      if (!model || !(camera instanceof PerspectiveCamera)) return;
      const bbox = new Box3().setFromObject(model);
      // An empty model yields an inverted box (infinite diagonal, NaN center); fitting to it would corrupt the camera.
      if (bbox.isEmpty()) return;
      const diameter = bbox.min.distanceTo(bbox.max);
      const bboxCenter = bbox.getCenter(new Vector3());

      // Camera-space +Z points behind the camera (three.js cameras look down -Z).
      const camBackward = new Vector3(0, 0, 1).applyQuaternion(camera.getWorldQuaternion(new Quaternion()));

      // Distance at which the box diagonal spans the vertical field of view.
      const fitHeightDist = diameter / 2 / Math.tan(Math.PI * camera.fov / 360);
      const viewportAspect = size.width / size.height;
      // In portrait viewports the horizontal extent is limiting, so back off further.
      const fitDist = viewportAspect > 1 ? fitHeightDist : fitHeightDist / viewportAspect;
      const targetPos = bboxCenter.addScaledVector(camBackward, fitFac * fitDist);

      // Start the spring at the current position, then animate toward the target.
      springApi.set({ x: camera.position.x, y: camera.position.y, z: camera.position.z });
      springApi.start({ x: targetPos.x, y: targetPos.y, z: targetPos.z });
    },
  }));

  return (
    <group ref={controlsParentRef}>
    </group>
  );
});

ViewRControls.displayName = "ViewRControls";

export default ViewRControls;
