import * as React from "react";
import {
  Table,
  TableBody,
  TableCell,
  TableHead,
  TableHeader,
  TableRow,
} from "@/components/ui/table";
import { LoadingState } from "@/components/shared/loading-state";
import { ErrorState } from "@/components/shared/error-state";
import { EmptyState } from "@/components/shared/empty-state";
import { ArrowDown, ArrowUp, ArrowUpDown } from "lucide-react";

/**
 * Direction of the active sort column, as passed to `DataTable` via `sortOrder`.
 * Exported so callers can type the sort state they own.
 */
export type SortDirection = "asc" | "desc";

/**
 * Describes a single column rendered by `DataTable`.
 *
 * Exported so callers can declare column arrays outside JSX with an explicit
 * `ColumnDef<Row>[]` annotation.
 *
 * @typeParam TData - The row type the table renders.
 */
export interface ColumnDef<TData> {
  /**
   * Identifier for the column. Used as the React key for its header and body
   * cells, and passed to `onSort` when the column's sort button is clicked.
   */
  id: string;
  /** Content rendered inside the header cell. */
  header: React.ReactNode;
  /** Renders this column's body cell for the given row. */
  cell: (row: TData) => React.ReactNode;
  /**
   * When true and the table receives an `onSort` handler, the header is
   * rendered as a sort toggle button with a direction indicator.
   */
  sortable?: boolean;
  /** Class names applied to the header cell and to every body cell of this column. */
  className?: string;
}

/**
 * Props for `DataTable`.
 *
 * Sorting and pagination are fully controlled: the table only reports user
 * intent through `onSort` / `pagination.onPageChange` and renders whatever
 * `data`, `sortBy`, `sortOrder` and `pagination.page` it is given.
 *
 * @typeParam TData - The row type; each element of `data` is handed to the
 * column `cell` renderers, `rowKey`, and `onRowClick`.
 */
export interface DataTableProps<TData> {
  /** Column definitions, rendered left to right. Readonly arrays are accepted. */
  columns: readonly ColumnDef<TData>[];
  /** Rows to render. An empty array shows the empty state. Readonly arrays are accepted. */
  data: readonly TData[];
  /** When true, a loading placeholder is shown instead of the table. Checked first. */
  isLoading?: boolean;
  /** When true (and not loading), an error state is shown instead of the table. */
  isError?: boolean;
  /** Raw error value; its `message` is displayed when `errorMessage` is not provided. */
  error?: unknown;
  /** Explicit error text; takes precedence over anything derived from `error`. */
  errorMessage?: string;
  /** Title of the empty state shown when `data` is empty. Defaults to "No data found". */
  emptyMessage?: string;
  /** Called with the clicked row. When provided, rows get a pointer cursor. */
  onRowClick?: (row: TData) => void;
  /** Returns a React key for a row; must be unique among the rendered rows. */
  rowKey: (row: TData) => string;
  /** Optional previous/next pagination controls. */
  pagination?: {
    /** Current page, 1-based ("Previous" is disabled at page 1). */
    page: number;
    /** Total number of pages ("Next" is disabled once `page` reaches it). */
    totalPages: number;
    /** Called with the requested page number when a navigation button is clicked. */
    onPageChange: (page: number) => void;
  };
  /** `id` of the column currently sorted, if any. */
  sortBy?: string;
  /** Direction for `sortBy`; any value other than "asc" is indicated as descending. */
  sortOrder?: SortDirection;
  /**
   * Called with a column `id` when its sort button is clicked. The table does
   * not toggle direction itself; the caller updates `sortBy` / `sortOrder`.
   */
  onSort?: (columnId: string) => void;
}

/**
 * Sort-direction indicator rendered after a sortable column's header.
 *
 * Columns that are not the active sort column get a muted two-way arrow. The
 * active column gets an up arrow for "asc" and a down arrow for any other
 * `sortOrder` value (including `undefined`).
 */
function SortIcon({
  columnId,
  sortBy,
  sortOrder,
}: {
  columnId: string;
  sortBy: string | undefined;
  sortOrder: SortDirection | undefined;
}) {
  const isActiveColumn = sortBy === columnId;

  if (isActiveColumn) {
    // Pick the arrow component first, then render it with the shared active styling.
    const DirectionIcon = sortOrder === "asc" ? ArrowUp : ArrowDown;
    return <DirectionIcon className="text-primary ml-1.5 h-3.5 w-3.5" />;
  }

  return <ArrowUpDown className="text-muted-foreground/40 ml-1.5 h-3.5 w-3.5" />;
}

/**
 * Pulls a displayable message out of an arbitrary error value.
 *
 * Accepts `Error` instances and plain objects with a `message` field. Returns
 * `undefined` unless that field is a non-empty string, so the caller can fall
 * back to a default instead of rendering an empty or non-string message.
 */
function getErrorMessage(error: unknown): string | undefined {
  if (typeof error === "object" && error !== null && "message" in error) {
    // The `in` check guarantees the property exists, but its type is still unknown.
    const { message } = error as { message: unknown };
    if (typeof message === "string" && message.length > 0) {
      return message;
    }
  }
  return undefined;
}

/**
 * Previous / "Page X of Y" / Next controls for `DataTable`.
 *
 * Buttons are `type="button"` so they never submit an enclosing form.
 */
function DataTablePagination({
  pagination,
}: {
  pagination: NonNullable<DataTableProps<unknown>["pagination"]>;
}) {
  return (
    <div className="flex items-center justify-end space-x-2">
      <button
        type="button"
        onClick={() => pagination.onPageChange(pagination.page - 1)}
        disabled={pagination.page <= 1}
        className="rounded border px-3 py-1 disabled:opacity-50"
      >
        Previous
      </button>
      <span className="text-muted-foreground text-sm">
        Page {pagination.page} of {pagination.totalPages}
      </span>
      <button
        type="button"
        onClick={() => pagination.onPageChange(pagination.page + 1)}
        disabled={pagination.page >= pagination.totalPages}
        className="rounded border px-3 py-1 disabled:opacity-50"
      >
        Next
      </button>
    </div>
  );
}

/**
 * Generic, fully controlled data table.
 *
 * Render precedence: loading state, then error state, then empty state, then
 * the table itself. Sorting and pagination are driven entirely by props; user
 * actions are reported through `onSort` and `pagination.onPageChange`.
 */
export function DataTable<TData>({
  columns,
  data,
  isLoading,
  isError,
  error,
  errorMessage,
  emptyMessage = "No data found",
  onRowClick,
  rowKey,
  pagination,
  sortBy,
  sortOrder,
  onSort,
}: DataTableProps<TData>) {
  if (isLoading) {
    return <LoadingState />;
  }

  if (isError) {
    const message = errorMessage ?? getErrorMessage(error) ?? "An error occurred";
    return <ErrorState message={message} />;
  }

  if (!data.length) {
    // A page past the first can come back empty (e.g. its last row was deleted).
    // Keep the controls visible so the user can navigate back instead of being stuck.
    if (pagination && pagination.page > 1) {
      return (
        <div className="space-y-4">
          <EmptyState title={emptyMessage} />
          <DataTablePagination pagination={pagination} />
        </div>
      );
    }
    return <EmptyState title={emptyMessage} />;
  }

  return (
    <div className="space-y-4">
      <div className="relative w-full overflow-x-auto">
        <Table>
          <TableHeader>
            <TableRow>
              {columns.map((col) => {
                // Only the column currently driving the sort carries aria-sort; the
                // direction mirrors SortIcon (anything other than "asc" shows as descending).
                const ariaSort =
                  col.sortable && onSort && sortBy === col.id
                    ? sortOrder === "asc"
                      ? "ascending"
                      : "descending"
                    : undefined;
                // Prefer the visible header text so the accessible name contains the
                // on-screen label; fall back to the id for non-text headers.
                const sortLabel = typeof col.header === "string" ? col.header : col.id;

                return (
                  <TableHead key={col.id} className={col.className} aria-sort={ariaSort}>
                    {col.sortable && onSort ? (
                      <button
                        type="button"
                        onClick={() => onSort(col.id)}
                        className="hover:text-foreground flex items-center gap-0 font-medium transition-colors"
                        aria-label={`Sort by ${sortLabel}`}
                      >
                        {col.header}
                        <SortIcon columnId={col.id} sortBy={sortBy} sortOrder={sortOrder} />
                      </button>
                    ) : (
                      col.header
                    )}
                  </TableHead>
                );
              })}
            </TableRow>
          </TableHeader>
          <TableBody>
            {data.map((row) => (
              <TableRow
                key={rowKey(row)}
                className={onRowClick ? "cursor-pointer" : undefined}
                onClick={onRowClick ? () => onRowClick(row) : undefined}
              >
                {columns.map((col) => (
                  <TableCell key={col.id} className={col.className}>
                    {col.cell(row)}
                  </TableCell>
                ))}
              </TableRow>
            ))}
          </TableBody>
        </Table>
      </div>

      {pagination && pagination.totalPages > 1 && (
        <DataTablePagination pagination={pagination} />
      )}
    </div>
  );
}
