/* École CNC'2, Nancy, 25-26 June 2009

   Problem 2.4: compute x^(-1/n) = 1 / x^(1/n), correctly rounded, with MPFR. */

#include <errno.h>
#include <stdlib.h>
#include <stdio.h>
#include <mpfr.h>

int
rec_nroot (mpfr_ptr y, mpfr_srcptr x, unsigned long n, mp_rnd_t rnd)
{
  int inex1, inex2;
  mpfr_t u;
  mp_prec_t prec_y = mpfr_get_prec (y);
  mp_prec_t prec = mpfr_get_prec (y);

  if (mpfr_zero_p (x)
      || (n%2 == 0 && mpfr_cmp_ui (x, 0) < 0))
    {
      mpfr_set_nan (y);
      return 0;
    }

  mpfr_init (u);

  for (;;)
    {
      prec += 8;
      mpfr_set_prec (u, prec);
      inex1 = mpfr_root (u, x, n, GMP_RNDD);
      inex2 = mpfr_ui_div (u, 1, u, GMP_RNDU);
      if ((inex1 == 0 && inex2 == 0)
          || mpfr_can_round (u, prec - 2, GMP_RNDU, GMP_RNDZ,
                             prec_y + (rnd == GMP_RNDN)))
        break;
    }

  inex1 = mpfr_set (y, u, rnd);
  mpfr_clear (u);

  return inex1;
}

int
main (int argc, char **argv)
{
  unsigned long n;
  mpfr_t x, y;
  mpfr_prec_t prec;
  int inex;

  mpfr_init (x);
  mpfr_init (y);

  if (argc != 4)
	{
	  printf ("Usage: %s n prec x_val\n", argv[0]);
	  return 1;
	}

  errno = 0;
  n = strtoul (argv[1], NULL, 10);
  if (errno || n < 2)
	{
	  printf ("Invalid n value\n");
	  return 1;
	}

  errno = 0;
  prec = strtoul (argv[2], NULL, 10);
  if (errno || prec < 2 || prec > MPFR_PREC_MAX)
	{
	  printf ("Invalid precision\n");
	  return 1;
	}

  mpfr_set_prec (x, prec);
  mpfr_set_prec (y, prec);

  if (mpfr_set_str (x, argv[3], 10, GMP_RNDN) == -1)
	{
	  printf ("Can't read floating-point value\n");
	  return 1;
	}

  inex = rec_nroot (y, x, n, GMP_RNDD);
  mpfr_printf ("RNDD: inexact %+d\t%.*Re\n", inex, 692, y);
  inex = rec_nroot (y, x, n, GMP_RNDN);
  mpfr_printf ("RNDN: inexact %+d\t%.*Re\n", inex, 693, y);
  inex = rec_nroot (y, x, n, GMP_RNDU);
  mpfr_printf ("RNDU: inexact %+d\t%.*Re\n", inex, 692, y);

  mpfr_clear (x);
  mpfr_clear (y);
  return 0;
}

