import { Immutable } from 'immer';
import {
  TraceASTExpr,
  TraceASTInput,
  TraceASTReference,
  TraceASTType,
} from '../trace';
import assertNever from '@watershed/shared-util/assertNever';
import { UnexpectedError } from '@watershed/errors/UnexpectedError';

// DEAR AUTHOR:
// This file houses utility functions that do things with TraceASTs. They must
// be generic to any context that uses a TraceAST: if it's not generic, put it
// somewhere else.

/**
 * Lazily walks a TraceAST and yields each node exactly once, starting with the
 * node that was passed in. Nodes come out in depth-first, pre-order sequence:
 * a parent is yielded before its children, and for binary nodes the whole `l`
 * subtree is yielded before anything from the `r` subtree.
 *
 * @param trace The root of the (sub)tree to walk.
 * @returns An iterable over `trace` and all of its descendants.
 */
export function* yieldSubtraces(
  trace: Immutable<TraceASTExpr>
): Iterable<TraceASTExpr> {
  // Explicit LIFO worklist: the entry on top is always the next node in
  // pre-order.
  const pending: Array<Immutable<TraceASTExpr>> = [trace];

  while (pending.length > 0) {
    // The length check above guarantees there is an entry to pop.
    const node = pending.pop() as Immutable<TraceASTExpr>;
    yield node;

    const { t } = node;
    switch (t) {
      // Leaf nodes: nothing underneath to visit.
      case TraceASTType.TNumber:
      case TraceASTType.TInput:
      case TraceASTType.TReference:
        break;

      // Single-child wrappers: the child lives in `e`.
      case TraceASTType.TUnary:
      case TraceASTType.TAnnotation:
      case TraceASTType.TConversion:
        pending.push(node.e);
        break;

      // Two children: push the right side first so that the left side is
      // popped (and therefore fully visited) before it.
      case TraceASTType.TBinary:
        pending.push(node.r, node.l);
        break;

      // Exhaustiveness check over the node type tag.
      default:
        assertNever(t);
    }
  }
}

/**
 * Finds the first node, in pre-order, that is not an annotation. That node may
 * be `trace` itself. A trace made up of *only* nested annotations would have
 * to be infinite, so rather than search forever we give up and throw once
 * `maxDepth` subtraces have been inspected.
 *
 * @param trace The trace to search, starting at its root.
 * @param maxDepth How many subtraces to inspect before giving up. At least one
 *   subtrace is always inspected.
 * @returns The first non-annotation subtrace encountered.
 * @throws UnexpectedError if every inspected subtrace was an annotation.
 */
export function getFirstNonAnnotationSubtrace(
  trace: Immutable<TraceASTExpr>,
  maxDepth = 50
): Immutable<TraceASTExpr> {
  // Number of further subtraces we are still willing to look at.
  let remaining = maxDepth;

  for (const candidate of yieldSubtraces(trace)) {
    if (candidate.t !== TraceASTType.TAnnotation) {
      return candidate;
    }

    // This one was an annotation; charge it against the budget.
    remaining -= 1;
    if (remaining <= 0) {
      break;
    }
  }

  throw new UnexpectedError(
    `Did not find any non-annotation subtraces after checking ${maxDepth} subtraces`
  );
}

/**
 * Lazily yields every TReference node found in a trace, in the same
 * depth-first, pre-order sequence that `yieldSubtraces` produces.
 *
 * @param trace The root of the (sub)tree to search.
 * @returns An iterable over the TReference nodes in `trace`.
 */
export function* yieldReferenceSubtraces(
  trace: Immutable<TraceASTExpr>
): Iterable<TraceASTReference> {
  for (const node of yieldSubtraces(trace)) {
    // Skip everything that is not a reference.
    if (node.t !== TraceASTType.TReference) {
      continue;
    }
    yield node;
  }
}

/**
 * Lazily yields every TInput node found in a trace, in the same depth-first,
 * pre-order sequence that `yieldSubtraces` produces.
 *
 * @param trace The root of the (sub)tree to search.
 * @returns An iterable over the TInput nodes in `trace`.
 */
export function* yieldInputSubtraces(
  trace: Immutable<TraceASTExpr>
): Iterable<TraceASTInput> {
  for (const node of yieldSubtraces(trace)) {
    // Skip everything that is not an input.
    if (node.t !== TraceASTType.TInput) {
      continue;
    }
    yield node;
  }
}
