c
c     m x n matrix-vector product y = y + A*x
c
c     Compares the performance of a FORALL
c     style matvec with a PDO style matvec
c
      program main
c
c     Driver: times two implementations of y = y + A*x for an m x n
c     matrix A (m and n are taken from the include file mv.h):
c       1. "forall" style: form b(:,j) = a(:,j)*x(j) for every column
c          in one loop, then sum the columns of b in a second loop
c          that uses a user-defined reduction.
c       2. "pdo" style: accumulate a(:,j)*x(j) directly in a single
c          reduction loop, without the intermediate array b.
c     The clock counts returned by timer_stop for each variant are
c     passed to prperf for reporting; verify checks each result.
c
      include 'mv.h'

      real a(m,n),b(m,n),x(n),y(m)
c     result receives the reduced column sum of one variant
      real result(m)
      integer j
      integer pdo_clocks, forall_clocks
      external timer_stop
      integer timer_stop
c     left/right are applied to result inside pmerge to combine two
c     partial results; presumably supplied by the pdo dialect --
c     confirm.  Naming the intrinsic real in this type statement only
c     types it; it is still used below as the conversion function.
      real left,right,real

c     define the data distributions: a template t of n elements,
c     distributed block(64) (presumably 64 consecutive template
c     elements per block -- confirm).  Column j of a and b and
c     element x(j) are all aligned with t(j).
c     NOTE(review): y has m elements but is aligned with t(1..n);
c     confirm that m .le. n or that the dialect tolerates this.
      template t(n)
      align a(i,j) with t(j)
      align b(i,j) with t(j)
      align x(j) with t(j)
      align y(j) with t(j)
      distribute t(block(64))

c     initialize the input arrays: a(i,j) = j and x(j) = 1, so every
c     element of A*x equals 1 + 2 + ... + n = n*(n+1)/2, which is the
c     value verify checks for.
      pdo j = 1,n 
      pout a(:,j),x(j)
         a(:,j) = real(j)
         x(j) = 1.0
      endpdo
c
c     forall based mv product (timed from timer_start to timer_stop)
c
      y = 0.0
      call timer_start()

c     step 1, forall-like loop: scale column j of a by x(j) into
c     column j of b.  pin/pout appear to declare the data each
c     iteration reads and writes.
      pdo j=1,n
      pin a(:,j),x(j)
      pout b(:,j)
         b(:,j) = a(:,j)*x(j)
      endpdo

c     step 2, user defined reduction: sum the columns of b into
c     result.  pinit sets the starting value of result, pbody adds
c     one column, and pmerge combines two partial results (presumably
c     one per subset of the j range -- confirm dialect semantics).
      pdo j=1,n
      pin b(:,j)
      pmvars result
      pinit
         result = 0.0
      pbody
         result = result + b(:,j)
      pmerge
         result = left(result) + right(result) 
      endpdo
      y = y + result
c     the clock is stopped before verify, so checking is not timed
      forall_clocks = timer_stop()
      call verify(result)
      
c
c     pdo-based mv product: the same reduction, but each iteration
c     adds a(:,j)*x(j) directly, so the intermediate array b is not
c     used.  y is reset first so both variants start from zero.
c
      y = 0
      call timer_start()
      pdo j=1,n
      pin a(:,j),x(j)
      pmvars result
      pinit
         result = 0.0
      pbody
         result = result + a(:,j)*x(j)
      pmerge
         result = left(result) + right(result)
      endpdo
      y = y + result
      pdo_clocks = timer_stop()
      call verify(result)
c     report both timings on one line
      call prperf(m,n,forall_clocks, pdo_clocks) 
      end

      subroutine verify(result)
c
c     Checks a computed matrix-vector product.  With a(i,j) = j and
c     x(j) = 1 (as initialized by the main program), every element of
c     result should equal 1 + 2 + ... + n = n*(n+1)/2.  Prints the
c     number of mismatching elements, if there are any.
c
c     result - real array of length m (m and n come from mv.h); the
c              caller passes an array declared result(m), so the
c              dummy must also have extent m, not n.
c
      include 'mv.h'
      real result(m)
      integer errors, i
      real expect, tol

c     Compute the expected sum in real arithmetic so that n*(n+1)
c     cannot overflow default integer for large n.
      expect = 0.5*real(n)*real(n+1)
c     Allow the worst-case rounding of an n-term single-precision sum
c     (about n*u*sum, with u = 2**-24; assumes IEEE single-precision
c     default reals -- confirm).  The two product variants sum in
c     different orders, so an exact compare fails spuriously once the
c     sum exceeds 2**24.  For small n this tolerance is below 1, so
c     the (exact, integer-valued) sums are still checked exactly.
      tol = real(n)*5.96e-8*expect

      errors = 0
      do i=1,m
c        written as .not.(... .le. ...) so a NaN result is counted
         if (.not. (abs(result(i) - expect) .le. tol)) then
            errors = errors + 1
         endif
      enddo
      if (errors .gt. 0) print *, "errors=", errors
      end

      subroutine prperf(m, n, forall_clocks, pdo_clocks) 
c
c     Prints one list-directed line of performance results:
c       m, n,
c       pdo clocks, pdo milliseconds, pdo MFLOPS,
c       forall clocks, forall milliseconds, forall MFLOPS
c     A matvec is counted as 2*m*n floating-point operations (one
c     multiply and one add per matrix element).
c
c     m, n          - matrix dimensions
c     forall_clocks - clock count measured for the forall variant
c     pdo_clocks    - clock count measured for the pdo variant
c
      integer m, n, forall_clocks, pdo_clocks
      real pdo_msecs, forall_msecs
      real flops, pdo_mflops, forall_mflops
c     timer clocks per millisecond (the divisor used by this report;
c     it implies a 20 MHz timer -- confirm against timer_stop)
      real clkpms
      parameter (clkpms = 20000.0)

      pdo_msecs = real(pdo_clocks)/clkpms
      forall_msecs = real(forall_clocks)/clkpms
c     convert m and n to real before multiplying so that m*n cannot
c     overflow default integer for large matrices
      flops = 2.0*real(m)*real(n)
c     MFLOPS = (flops per millisecond) / 1000.  Report 0 instead of
c     dividing by zero when the measured time is not positive.
      pdo_mflops = 0.0
      if (pdo_msecs .gt. 0.0) pdo_mflops = (flops/pdo_msecs)/1000.0
      forall_mflops = 0.0
      if (forall_msecs .gt. 0.0)
     $     forall_mflops = (flops/forall_msecs)/1000.0
      print *, m, n,
     $     pdo_clocks, pdo_msecs, pdo_mflops,
     $     forall_clocks, forall_msecs, forall_mflops
      end
