view liboctave/operators/Sparse-perm-op-defs.h @ 33248:7f73e4805a1f

replace calls to assert with liboctave_panic functions liboctave * lo-error.h (liboctave_panic_impossible, liboctave_panic_if, liboctave_panic_unless): New macros. Replace all calls to assert with panic_if, panic_impossible, or panic_unless as appropriate. Include "lo-error.h" as needed. Don't include <cassert>. Affected files: Array-base.cc, Array-util.cc, Array.h, CSparse.cc, DiagArray2.cc, 1 DiagArray2.h, Sparse.cc, Sparse.h, dim-vector.h, idx-vector.cc, idx-vector.h, CollocWt.cc, Quad.cc, oct-rand.cc, qrp.cc, svd.cc, Sparse-perm-op-defs.h, and oct-sort.cc.
author John W. Eaton <jwe@octave.org>
date Sun, 24 Mar 2024 18:12:34 -0400
parents 2e484f9f1f18
children
line wrap: on
line source

////////////////////////////////////////////////////////////////////////
//
// Copyright (C) 2009-2024 The Octave Project Developers
//
// See the file COPYRIGHT.md in the top-level directory of this
// distribution or <https://octave.org/copyright/>.
//
// This file is part of Octave.
//
// Octave is free software: you can redistribute it and/or modify it
// under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Octave is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Octave; see the file COPYING.  If not, see
// <https://www.gnu.org/licenses/>.
//
////////////////////////////////////////////////////////////////////////

#if ! defined (octave_Sparse_perm_op_defs_h)
#define octave_Sparse_perm_op_defs_h 1

#include "octave-config.h"

#include "PermMatrix.h"
#include "lo-array-errwarn.h"
#include "oct-locbuf.h"
#include "oct-sort.h"
#include "quit.h"

// Matrix multiplication

template <typename SM>
SM octinternal_do_mul_colpm_sm (const octave_idx_type *pcol, const SM& a)
// Relabel the rows according to pcol.
{
  const octave_idx_type nr = a.rows ();
  const octave_idx_type nc = a.cols ();
  const octave_idx_type nent = a.nnz ();
  SM r (nr, nc, nent);

  octave_sort<octave_idx_type> sort;

  for (octave_idx_type j = 0; j <= nc; ++j)
    r.xcidx (j) = a.cidx (j);

  for (octave_idx_type j = 0; j < nc; j++)
    {
      octave_quit ();

      OCTAVE_LOCAL_BUFFER (octave_idx_type, sidx, r.xcidx (j+1) - r.xcidx (j));
      for (octave_idx_type i = r.xcidx (j), ii = 0; i < r.xcidx (j+1); i++)
        {
          sidx[ii++]=i;
          r.xridx (i) = pcol[a.ridx (i)];
        }
      sort.sort (r.xridx () + r.xcidx (j), sidx, r.xcidx (j+1) - r.xcidx (j));
      for (octave_idx_type i = r.xcidx (j), ii = 0; i < r.xcidx (j+1); i++)
        r.xdata (i) = a.data (sidx[ii++]);
    }

  return r;
}

template <typename SM>
SM octinternal_do_mul_pm_sm (const PermMatrix& p, const SM& a)
{
  const octave_idx_type nr = a.rows ();
  if (p.cols () != nr)
    octave::err_nonconformant ("operator *",
                               p.rows (), p.cols (), a.rows (), a.cols ());

  return octinternal_do_mul_colpm_sm (p.col_perm_vec ().data (), a);
}

template <typename SM>
SM octinternal_do_mul_sm_rowpm (const SM& a, const octave_idx_type *prow)
// For a row permutation, iterate across the source a and stuff the
// results into the correct destination column in r.
{
  const octave_idx_type nr = a.rows ();
  const octave_idx_type nc = a.cols ();
  const octave_idx_type nent = a.nnz ();
  SM r (nr, nc, nent);

  for (octave_idx_type j_src = 0; j_src < nc; ++j_src)
    r.xcidx (prow[j_src]) = a.cidx (j_src+1) - a.cidx (j_src);
  octave_idx_type k = 0;
  for (octave_idx_type j = 0; j < nc; ++j)
    {
      const octave_idx_type tmp = r.xcidx (j);
      r.xcidx (j) = k;
      k += tmp;
    }
  r.xcidx (nc) = nent;

  octave_idx_type k_src = 0;
  for (octave_idx_type j_src = 0; j_src < nc; ++j_src)
    {
      octave_quit ();
      const octave_idx_type j = prow[j_src];
      const octave_idx_type kend_src = a.cidx (j_src + 1);
      for (k = r.xcidx (j); k_src < kend_src; ++k, ++k_src)
        {
          r.xridx (k) = a.ridx (k_src);
          r.xdata (k) = a.data (k_src);
        }
    }
  liboctave_panic_unless (k_src == nent);

  return r;
}

template <typename SM>
SM octinternal_do_mul_sm_colpm (const SM& a, const octave_idx_type *pcol)
// For a column permutation, iterate across the destination r and pull
// data from the correct column of a.
{
  const octave_idx_type nr = a.rows ();
  const octave_idx_type nc = a.cols ();
  const octave_idx_type nent = a.nnz ();
  SM r (nr, nc, nent);

  for (octave_idx_type j = 0; j < nc; ++j)
    {
      const octave_idx_type j_src = pcol[j];
      r.xcidx (j+1) = r.xcidx (j) + (a.cidx (j_src+1) - a.cidx (j_src));
    }
  liboctave_panic_unless (r.xcidx (nc) == nent);

  octave_idx_type k = 0;
  for (octave_idx_type j = 0; j < nc; ++j)
    {
      octave_quit ();
      const octave_idx_type j_src = pcol[j];
      octave_idx_type k_src;
      const octave_idx_type kend_src = a.cidx (j_src + 1);
      for (k_src = a.cidx (j_src); k_src < kend_src; ++k_src, ++k)
        {
          r.xridx (k) = a.ridx (k_src);
          r.xdata (k) = a.data (k_src);
        }
    }
  liboctave_panic_unless (k == nent);

  return r;
}

template <typename SM>
SM octinternal_do_mul_sm_pm (const SM& a, const PermMatrix& p)
{
  const octave_idx_type nc = a.cols ();
  if (p.rows () != nc)
    octave::err_nonconformant ("operator *",
                               a.rows (), a.cols (), p.rows (), p.cols ());

  return octinternal_do_mul_sm_colpm (a, p.col_perm_vec ().data ());
}

#endif