comparison src/jit-typeinfo.cc @ 15067:df4538e3b50b

ND scalar indexing in JIT. * src/jit-ir.cc (jit_magic_end::jit_magic_end): Use jit_magic_end::context. * src/jit-ir.h (jit_call::jit_call): New overload. (jit_magic_end::context): New class. (jit_magic_end::jit_magic_end): moved to src/jit-ir.cc. * src/jit-typeinfo.cc (octave_jit_paren_scalar): New function. (jit_typeinfo::jit_typeinfo): Generate ND scalar indexing. (jit_typeinfo::gen_subsref): New function. * src/jit-typeinfo.h (jit_typeinfo::gen_subsref): New declaration. * src/pt-jit.cc (jit_convert::visit_index_expression, jit_convert::do_assign): Update resolve call. (jit_convert::resolve): Resolve ND indices. * src/pt-jit.h (jit_convert::resolve): Change function signature.
author Max Brister <max@2bass.com>
date Tue, 31 Jul 2012 11:51:01 -0500
parents bc32288f4a42
children f57d7578c1a6
comparison
equal deleted inserted replaced
15066:6451a584305e 15067:df4538e3b50b
239 double *data = array->fortran_vec (); 239 double *data = array->fortran_vec ();
240 data[index - 1] = value; 240 data[index - 1] = value;
241 241
242 mat->update (); 242 mat->update ();
243 *ret = *mat; 243 *ret = *mat;
244 }
245
246 extern "C" double
247 octave_jit_paren_scalar (jit_matrix *mat, double *indicies,
248 octave_idx_type idx_count)
249 {
250 // FIXME: Replace this with a more optimal version
251 try
252 {
253 Array<idx_vector> idx (dim_vector (1, idx_count));
254 for (octave_idx_type i = 0; i < idx_count; ++i)
255 idx(i) = idx_vector (indicies[i]);
256
257 Array<double> ret = mat->array->index (idx);
258 return ret.xelem (0);
259 }
260 catch (const octave_execution_exception&)
261 {
262 gripe_library_execution_error ();
263 return 0;
264 }
244 } 265 }
245 266
246 extern "C" void 267 extern "C" void
247 octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat, 268 octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat,
248 jit_range *index, double value) 269 jit_range *index, double value)
786 scalar = new_type ("scalar", complex, scalar_t); 807 scalar = new_type ("scalar", complex, scalar_t);
787 range = new_type ("range", any, range_t); 808 range = new_type ("range", any, range_t);
788 string = new_type ("string", any, string_t); 809 string = new_type ("string", any, string_t);
789 boolean = new_type ("bool", any, bool_t); 810 boolean = new_type ("bool", any, bool_t);
790 index = new_type ("index", any, index_t); 811 index = new_type ("index", any, index_t);
812
813 // a fake type for interfacing with C++
814 jit_type *scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ());
791 815
792 create_int (8); 816 create_int (8);
793 create_int (16); 817 create_int (16);
794 create_int (32); 818 create_int (32);
795 create_int (64); 819 create_int (64);
1308 merge->addIncoming (ret, success); 1332 merge->addIncoming (ret, success);
1309 fn.do_return (builder, merge); 1333 fn.do_return (builder, merge);
1310 } 1334 }
1311 paren_subsref_fn.add_overload (fn); 1335 paren_subsref_fn.add_overload (fn);
1312 1336
1337 // generate () subsref for ND indexing of matricies with scalars
1338 jit_function paren_scalar = create_function (jit_convention::external,
1339 "octave_jit_paren_scalar",
1340 scalar, matrix, scalar_ptr,
1341 index);
1342 paren_scalar.add_mapping (engine, &octave_jit_paren_scalar);
1343 paren_scalar.mark_can_error ();
1344
1345 // FIXME: Generate this on the fly
1346 for (size_t i = 2; i < 10; ++i)
1347 gen_subsref (paren_scalar, i);
1348
1313 // paren subsasgn 1349 // paren subsasgn
1314 paren_subsasgn_fn.stash_name ("()subsasgn"); 1350 paren_subsasgn_fn.stash_name ("()subsasgn");
1315 1351
1316 jit_function resize_paren_subsasgn 1352 jit_function resize_paren_subsasgn
1317 = create_function (jit_convention::external, 1353 = create_function (jit_convention::external,
1829 jit_type *ret = new jit_type (name, parent, llvm_type, next_id++); 1865 jit_type *ret = new jit_type (name, parent, llvm_type, next_id++);
1830 id_to_type.push_back (ret); 1866 id_to_type.push_back (ret);
1831 return ret; 1867 return ret;
1832 } 1868 }
1833 1869
1870 void
1871 jit_typeinfo::gen_subsref (const jit_function& paren_scalar, size_t n)
1872 {
1873 std::stringstream name;
1874 name << "jit_paren_subsref_matrix_scalar" << n;
1875 std::vector<jit_type *> args (n + 1, scalar);
1876 args[0] = matrix;
1877 jit_function fn = create_function (jit_convention::internal, name.str (),
1878 scalar, args);
1879 fn.mark_can_error ();
1880 llvm::BasicBlock *body = fn.new_block ();
1881 builder.SetInsertPoint (body);
1882
1883 llvm::Type *scalar_t = scalar->to_llvm ();
1884 llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n);
1885 llvm::Value *array = llvm::UndefValue::get (array_t);
1886 for (size_t i = 0; i < n; ++i)
1887 {
1888 llvm::Value *idx = fn.argument (builder, i + 1);
1889 array = builder.CreateInsertValue (array, idx, i);
1890 }
1891
1892 llvm::Value *array_mem = builder.CreateAlloca (array_t);
1893 builder.CreateStore (array, array_mem);
1894 array = builder.CreateBitCast (array_mem, scalar_t->getPointerTo ());
1895
1896 llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), n);
1897 llvm::Value *mat = fn.argument (builder, 0);
1898 llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem);
1899 fn.do_return (builder, ret);
1900 paren_subsref_fn.add_overload (fn);
1901 }
1902
1834 #endif 1903 #endif