ROSE  0.11.31.0
Database.h
1 // WARNING: Changes to this file must be contributed back to Sawyer or else they will
2 // be clobbered by the next update from Sawyer. The Sawyer repository is at
3 // https://github.com/matzke1/sawyer.
4 
5 
6 
7 
8 #ifndef Sawyer_Database_H
9 #define Sawyer_Database_H
10 
11 #if __cplusplus >= 201103L
12 
#include <boost/iterator/iterator_facade.hpp>
#include <boost/lexical_cast.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <cctype>
#include <cstdint>
#include <memory>
#include <memory.h>
#include <Sawyer/Assert.h>
#include <Sawyer/Map.h>
#include <Sawyer/Optional.h>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
22 
23 namespace Sawyer {
24 
167 namespace Database {
168 
169 class Connection;
170 class Statement;
171 class Row;
172 class Iterator;
173 
174 namespace Detail {
175  class ConnectionBase;
176  class StatementBase;
177 }
178 
// Exception type thrown by this database API.
//
// The message passed to the constructor is available through the inherited what() method. No destructor is declared: the
// implicitly generated one is already virtual and non-throwing (inherited from std::runtime_error), and leaving it implicit
// keeps the compiler-generated copy and move operations available.
class Exception: public std::runtime_error {
public:
    // Construct an exception that reports the specified message.
    Exception(const std::string &what)
        : std::runtime_error(what) {}
};
186 
188 // Connection
190 
196 class Connection {
197  friend class ::Sawyer::Database::Statement;
198  friend class ::Sawyer::Database::Detail::ConnectionBase;
199 
200  std::shared_ptr<Detail::ConnectionBase> pimpl_;
201 
202 public:
204  Connection() {};
205 
206 private:
207  explicit Connection(const std::shared_ptr<Detail::ConnectionBase> &pimpl);
208 
209 public:
213  ~Connection() = default;
214 
216  bool isOpen() const;
217 
223  Connection& close();
224 
239  Statement stmt(const std::string &sql);
240 
242  Connection& run(const std::string &sql);
243 
249  template<typename T>
250  Optional<T> get(const std::string &sql);
251 
255  std::string driverName() const;
256 
257  // Undocumented: Row number for the last SQL "insert" (do not use).
258  //
259  // This method is available only if the underlying database driver supports it and it has lots of caveats. In other words,
260  // don't use this method. The most portable way to identify the rows that were just inserted is to insert a UUID as part of
261  // the data.
262  size_t lastInsert() const;
263 
264  // Set the pointer to implementation
265  void pimpl(const std::shared_ptr<Detail::ConnectionBase> &p) {
266  pimpl_ = p;
267  }
268 };
269 
271 // Statement
273 
278 class Statement {
279  friend class ::Sawyer::Database::Detail::ConnectionBase;
280 
281  std::shared_ptr<Detail::StatementBase> pimpl_;
282 
283 public:
285  enum State {
286  UNBOUND,
287  READY,
288  EXECUTING,
289  FINISHED,
290  DEAD
291  };
292 
293 public:
295  Statement() {}
296 
297 private:
298  explicit Statement(const std::shared_ptr<Detail::StatementBase> &stmt)
299  : pimpl_(stmt) {}
300 
301 public:
303  Connection connection() const;
304 
314  template<typename T>
315  Statement& bind(const std::string &name, const T &value);
316 
321  template<typename T>
322  Statement& rebind(const std::string &name, const T &value);
323 
330  Iterator begin();
331 
333  Iterator end();
334 
339  Statement& run();
340 
345  template<typename T>
346  Optional<T> get();
347 };
348 
350 // Row
352 
357 class Row {
358  friend class ::Sawyer::Database::Iterator;
359 
360  std::shared_ptr<Detail::StatementBase> stmt_;
361  size_t sequence_; // for checking validity
362 
363 private:
364  Row()
365  : sequence_(0) {}
366 
367  explicit Row(const std::shared_ptr<Detail::StatementBase> &stmt);
368 
369 public:
371  template<typename T>
372  Optional<T> get(size_t columnIdx) const;
373 
377  size_t rowNumber() const;
378 };
379 
381 // Iterator
383 
390 class Iterator: public boost::iterator_facade<Iterator, const Row, boost::forward_traversal_tag> {
391  friend class ::Sawyer::Database::Detail::StatementBase;
392 
393  Row row_;
394 
395 public:
397  Iterator() {}
398 
399 private:
400  explicit Iterator(const std::shared_ptr<Detail::StatementBase> &stmt);
401 
402 public:
404  bool isEnd() const {
405  return !row_.stmt_;
406  }
407 
409  explicit operator bool() const {
410  return !isEnd();
411  }
412 
413 private:
414  friend class boost::iterator_core_access;
415  const Row& dereference() const;
416  bool equal(const Iterator&) const;
417  void increment();
418 };
419 
420 
424 //
425 // Only implementation details beyond this point.
426 //
430 
431 
432 namespace Detail {
433 
434 // Base class for connection details. The individual drivers (SQLite3, PostgreSQL) will be derived from this class.
435 //
436 // Connection detail objects are reference counted. References come from only two places:
437 // 1. Each top-level Connection object that's in a connected state has a reference to this connection.
438 // 2. Each low-level Statement object that's in an "executing" state has a reference to this connection.
439 // Additionally, all low-level statement objects have a weak reference to a connection.
440 //
441 class ConnectionBase: public std::enable_shared_from_this<ConnectionBase> {
442  friend class ::Sawyer::Database::Connection;
443 
444 protected:
445  ConnectionBase() {}
446 
447 public:
448  virtual ~ConnectionBase() {}
449 
450 protected:
451  // Close any low-level connection.
452  virtual void close() = 0;
453 
454  // Create a prepared statement from the specified high-level SQL. By "high-level" we mean the binding syntax used by this
455  // API such as "?name" (whereas low-level means the syntax passed to the driver such as "?").
456  virtual Statement prepareStatement(const std::string &sql) = 0;
457 
458  // Row number for the last inserted row if supported by this driver. It's better to use a table column that holds a value
459  // generated from a sequence.
460  virtual size_t lastInsert() const = 0;
461 
462  Statement makeStatement(const std::shared_ptr<Detail::StatementBase> &detail);
463 
464  virtual std::string driverName() const = 0;
465 };
466 
467 // Describes the location of "?name" parameters in high-level SQL by associating them with one or more "?" parameters in
468 // low-level SQL. WARNIN: the low-level parameters are numbered starting at one instead of zero, which is inconsistent with how
469 // the low-level APIs index other things like query result columns (not to mention being surprising for C and C++ developers).
470 class Parameter {
471  friend class ::Sawyer::Database::Detail::StatementBase;
472 
473  std::vector<size_t> indexes; // "?" indexes
474  bool isBound = false;
475 
476  void append(size_t idx) {
477  indexes.push_back(idx);
478  }
479 };
480 
481 template<typename T>
482 class ColumnReader {
483  friend class ::Sawyer::Database::Detail::StatementBase;
484  Optional<T> operator()(StatementBase *stmt, size_t idx);
485 };
486 
487 //template<>
488 //class ColumnReader<std::vector<uint8_t>> {
489 // friend class ::Sawyer::Database::Detail::StatementBase;
490 // Optional<std::vector<uint8_t>> operator()(StatementBase *stmt, size_t idx);
491 //};
492 
493 // Reference counted prepared statement details. Objects of this class are referenced from the high-level Statement objects and
494 // the query iterator rows. This class is the base class for driver-specific statements.
495 class StatementBase: public std::enable_shared_from_this<StatementBase> {
496  friend class ::Sawyer::Database::Iterator;
497  friend class ::Sawyer::Database::Row;
498  friend class ::Sawyer::Database::Statement;
499  template<class T> friend class ::Sawyer::Database::Detail::ColumnReader;
500 
501  using Parameters = Container::Map<std::string, Parameter>;
502 
503  std::shared_ptr<ConnectionBase> connection_; // non-null while statement is executing
504  std::weak_ptr<ConnectionBase> weakConnection_; // refers to the originating connection
505  Parameters params_; // mapping from param names to question marks
506  Statement::State state_ = Statement::DEAD; // don't set directly; use "state" member function
507  size_t sequence_ = 0; // sequence number for invalidating row iterators
508  size_t rowNumber_ = 0; // result row number
509 
510 public:
511  virtual ~StatementBase() {}
512 
513 protected:
514  explicit StatementBase(const std::shared_ptr<ConnectionBase> &connection)
515  : weakConnection_(connection) { // save only a weak pointer, no shared pointer
516  ASSERT_not_null(connection);
517  }
518 
519  // Parse the high-level SQL (with "?name" parameters) into low-level SQL (with "?" parameters). Returns the low-level SQL
520  // and the number of low-level "?" parameters and has the following side effects:
521  // 1. Re-initializes this object's parameter list
522  // 2. Sets this object's state to READY, UNBOUND, or DEAD.
523  std::pair<std::string, size_t> parseParameters(const std::string &highSql) {
524  params_.clear();
525  std::string lowSql;
526  bool inString = false;
527  size_t nLowParams = 0;
528  state(Statement::READY); // possibly reset below
529  for (size_t i = 0; i < highSql.size(); ++i) {
530  if ('\'' == highSql[i]) {
531  inString = !inString; // works for "''" escape too
532  lowSql += highSql[i];
533  } else if ('?' == highSql[i] && !inString) {
534  lowSql += '?';
535  std::string paramName;
536  while (i+1 < highSql.size() && (::isalnum(highSql[i+1]) || '_' == highSql[i+1]))
537  paramName += highSql[++i];
538  if (paramName.empty())
539  throw Exception("invalid parameter name at character position " + boost::lexical_cast<std::string>(i));
540  Parameter &param = params_.insertMaybeDefault(paramName);
541  param.append(nLowParams++); // 0-origin low-level parameter numbers
542  state(Statement::UNBOUND);
543  } else {
544  lowSql += highSql[i];
545  }
546  }
547  if (inString) {
548  state(Statement::DEAD);
549  throw Exception("mismatched quotes in SQL statement");
550  }
551  return std::make_pair(lowSql, nLowParams);
552  }
553 
554  // Invalidate all iterators and their rows by incrementing this statements sequence number.
555  void invalidateIteratorsAndRows() {
556  ++sequence_;
557  }
558 
559  // Sequence number used for checking iterator validity.
560  size_t sequence() const {
561  return sequence_;
562  }
563 
564  // Cause this statement to lock the database connection by maintaining a shared pointer to the low-level
565  // connection. Returns true if the connection could be locked, or false if unable.
566  bool lockConnection() {
567  return (connection_ = weakConnection_.lock()) != nullptr;
568  }
569 
570  // Release the connection lock by throwing away the shared pointer to the connection. This statement will still maintain
571  // a weak reference to the connection.
572  void unlockConnection() {
573  connection_.reset();
574  }
575 
576  // Returns an indication of whether this statement holds a lock on the low-level connection, preventing the connection from
577  // being destroyed.
578  bool isConnectionLocked() const {
579  return connection_ != nullptr;
580  }
581 
582  // Returns the connection details associated with this statement. The connection is not locked by querying this property.
583  std::shared_ptr<ConnectionBase> connection() const {
584  return weakConnection_.lock();
585  }
586 
587  // Return the current statement state.
588  Statement::State state() const {
589  return state_;
590  }
591 
592  // Change the statement state. A statement in the EXECUTING state will lock the connection to prevent it from being
593  // destroyed, but a statement in any other state will unlock the connection causing the last reference to destroy the
594  // connection and will invalidate all iterators and rows.
595  void state(Statement::State newState) {
596  switch (newState) {
597  case Statement::DEAD:
598  case Statement::FINISHED:
599  case Statement::UNBOUND:
600  case Statement::READY:
601  invalidateIteratorsAndRows();
602  unlockConnection();
603  break;
604  case Statement::EXECUTING:
605  ASSERT_require(isConnectionLocked());
606  break;
607  }
608  state_ = newState;
609  }
610 
611  // Returns true if this statement has parameters that have not been bound to a value.
612  bool hasUnboundParameters() const {
613  ASSERT_forbid(state() == Statement::DEAD);
614  for (const Parameter &param: params_.values()) {
615  if (!param.isBound)
616  return true;
617  }
618  return false;
619  }
620 
621  // Causes all parameters to become unbound and changes the state to either UNBOUND or READY (depending on whether there are
622  // any parameters or not, respectively).
623  virtual void unbindAllParams() {
624  ASSERT_forbid(state() == Statement::DEAD);
625  for (Parameter &param: params_.values())
626  param.isBound = false;
627  state(params_.isEmpty() ? Statement::READY : Statement::UNBOUND);
628  }
629 
630  // Reset the statement by invalidating all iterators, unbinding all parameters, and changing the state to either UNBOUND or
631  // READY depending on whether or not it has any parameters.
632  virtual void reset(bool doUnbind) {
633  ASSERT_forbid(state() == Statement::DEAD);
634  invalidateIteratorsAndRows();
635  if (doUnbind) {
636  unbindAllParams();
637  } else {
638  state(hasUnboundParameters() ? Statement::UNBOUND : Statement::READY);
639  }
640  }
641 
642  // Bind a value to a parameter. If isRebind is set and the statement is in the EXECUTING state, then rewind back to the
643  // READY state, preserve all previous bindings, and adjust only the specified binding.
644  template<typename T>
645  void bind(const std::string &name, const T &value, bool isRebind) {
646  if (!connection())
647  throw Exception("connection is closed");
648  switch (state()) {
649  case Statement::DEAD:
650  throw Exception("statement is dead");
651  case Statement::FINISHED:
652  case Statement::EXECUTING:
653  reset(!isRebind);
654  // fall through
655  case Statement::READY:
656  case Statement::UNBOUND: {
657  if (!params_.exists(name))
658  throw Exception("no such parameter \"" + name + "\" in statement");
659  Parameter &param = params_[name];
660  bool wasUnbound = !param.isBound;
661  for (size_t idx: param.indexes) {
662  try {
663  bindLow(idx, value);
664  } catch (const Exception &e) {
665  if (param.indexes.size() > 1)
666  state(Statement::DEAD); // might be only partly bound now
667  throw e;
668  }
669  }
670  param.isBound = true;
671 
672  if (wasUnbound && !hasUnboundParameters())
673  state(Statement::READY);
674  break;
675  }
676  }
677  }
678 
679  // Bind a value to an optional parameter.
680  template<typename T>
681  void bind(const std::string &name, const Sawyer::Optional<T> &value, bool isRebind) {
682  if (value) {
683  bind(name, *value, isRebind);
684  } else {
685  bind(name, Nothing(), isRebind);
686  }
687  }
688 
689  // Driver-specific part of binding by specifying the 0-origin low-level "?" number and the value.
690  virtual void bindLow(size_t idx, int value) = 0;
691  virtual void bindLow(size_t idx, int64_t value) = 0;
692  virtual void bindLow(size_t idx, size_t value) = 0;
693  virtual void bindLow(size_t idx, double value) = 0;
694  virtual void bindLow(size_t idx, const std::string &value) = 0;
695  virtual void bindLow(size_t idx, const char *cstring) = 0;
696  virtual void bindLow(size_t idx, Nothing) = 0;
697  virtual void bindLow(size_t idx, const std::vector<uint8_t> &data) = 0;
698 
699  Iterator makeIterator() {
700  return Iterator(shared_from_this());
701  }
702 
703  // Begin execution of a statement in the READY state. If the statement is in the FINISHED or EXECUTING state it will be
704  // restarted.
705  Iterator begin() {
706  if (!connection())
707  throw Exception("connection is closed");
708  switch (state()) {
709  case Statement::DEAD:
710  throw Exception("statement is dead");
711  case Statement::UNBOUND: {
712  std::string s;
713  for (Parameters::Node &param: params_.nodes()) {
714  if (!param.value().isBound)
715  s += (s.empty() ? "" : ", ") + param.key();
716  }
717  ASSERT_forbid(s.empty());
718  throw Exception("unbound parameters: " + s);
719  }
720  case Statement::FINISHED:
721  case Statement::EXECUTING:
722  reset(false);
723  // fall through
724  case Statement::READY: {
725  if (!lockConnection())
726  throw Exception("connection has been closed");
727  state(Statement::EXECUTING);
728  rowNumber_ = 0;
729  Iterator iter = beginLow();
730  rowNumber_ = 0; // in case beginLow changed it
731  return iter;
732  }
733  }
734  ASSERT_not_reachable("invalid state");
735  }
736 
737  // The driver-specific component of "begin". The statement is guaranteed to be in the EXECUTING state when called,
738  // but could be in some other state after returning.
739  virtual Iterator beginLow() = 0;
740 
741  // Advance an executing statement to the next row
742  Iterator next() {
743  if (!connection())
744  throw Exception("connection is closed");
745  ASSERT_require(state() == Statement::EXECUTING); // no other way to get here
746  invalidateIteratorsAndRows();
747  ++rowNumber_;
748  return nextLow();
749  }
750 
751  // Current row number
752  size_t rowNumber() const {
753  return rowNumber_;
754  }
755 
756  // The driver-specific component of "next". The statement is guaranteed to be in the EXECUTING state when called, but
757  // could be in some other state after returning.
758  virtual Iterator nextLow() = 0;
759 
760  // Get a column value from the current row of result
761  template<typename T>
762  Optional<T> get(size_t columnIdx) {
763  if (!connection())
764  throw Exception("connection is closed");
765  ASSERT_require(state() == Statement::EXECUTING); // no other way to get here
766  if (columnIdx >= nColumns())
767  throw Exception("column index " + boost::lexical_cast<std::string>(columnIdx) + " is out of range");
768  return ColumnReader<T>()(this, columnIdx);
769  }
770 
771  // Number of columns returned by a query.
772  virtual size_t nColumns() const = 0;
773 
774  // Get the value of a particular column of the current row.
775  virtual Optional<std::string> getString(size_t idx) = 0;
776  virtual Optional<std::vector<std::uint8_t>> getBlob(size_t idx) = 0;
777 };
778 
779 template<typename T>
780 inline Optional<T>
781 ColumnReader<T>::operator()(StatementBase *stmt, size_t idx) {
782  std::string str;
783  if (!stmt->getString(idx).assignTo(str))
784  return Nothing();
785  return boost::lexical_cast<T>(str);
786 }
787 
788 template<>
789 inline Optional<std::vector<uint8_t>>
790 ColumnReader<std::vector<uint8_t>>::operator()(StatementBase *stmt, size_t idx) {
791  return stmt->getBlob(idx);
792 }
793 
794 inline Statement
795 ConnectionBase::makeStatement(const std::shared_ptr<Detail::StatementBase> &detail) {
796  return Statement(detail);
797 }
798 
799 } // namespace
800 
802 // Implementations Connection
804 
805 
806 inline Connection::Connection(const std::shared_ptr<Detail::ConnectionBase> &pimpl)
807  : pimpl_(pimpl) {}
808 
809 inline bool
810 Connection::isOpen() const {
811  return pimpl_ != nullptr;
812 }
813 
814 inline Connection&
815 Connection::close() {
816  pimpl_ = nullptr;
817  return *this;
818 }
819 
820 inline std::string
821 Connection::driverName() const {
822  if (pimpl_) {
823  return pimpl_->driverName();
824  } else {
825  return "";
826  }
827 }
828 
829 inline Statement
830 Connection::stmt(const std::string &sql) {
831  if (pimpl_) {
832  return pimpl_->prepareStatement(sql);
833  } else {
834  throw Exception("no active database connection");
835  }
836 }
837 
838 inline Connection&
839 Connection::run(const std::string &sql) {
840  stmt(sql).begin();
841  return *this;
842 }
843 
844 template<typename T>
845 inline Optional<T>
846 Connection::get(const std::string &sql) {
847  for (auto row: stmt(sql))
848  return row.get<T>(0);
849  return Nothing();
850 }
851 
852 inline size_t
853 Connection::lastInsert() const {
854  if (pimpl_) {
855  return pimpl_->lastInsert();
856  } else {
857  throw Exception("no active database connection");
858  }
859 }
860 
862 // Implementations for Statement
864 
865 inline Connection
866 Statement::connection() const {
867  if (pimpl_) {
868  return Connection(pimpl_->connection());
869  } else {
870  return Connection();
871  }
872 }
873 
874 template<typename T>
875 inline Statement&
876 Statement::bind(const std::string &name, const T &value) {
877  if (pimpl_) {
878  pimpl_->bind(name, value, false);
879  } else {
880  throw Exception("no active database connection");
881  }
882  return *this;
883 }
884 
885 template<typename T>
886 inline Statement&
887 Statement::rebind(const std::string &name, const T &value) {
888  if (pimpl_) {
889  pimpl_->bind(name, value, true);
890  } else {
891  throw Exception("no active database connection");
892  }
893  return *this;
894 }
895 
896 inline Iterator
897 Statement::begin() {
898  if (pimpl_) {
899  return pimpl_->begin();
900  } else {
901  throw Exception("no active database connection");
902  }
903 }
904 
905 inline Iterator
906 Statement::end() {
907  return Iterator();
908 }
909 
910 inline Statement&
911 Statement::run() {
912  begin();
913  return *this;
914 }
915 
916 template<typename T>
917 inline Optional<T>
918 Statement::get() {
919  Iterator row = begin();
920  if (row.isEnd())
921  throw Exception("query did not return a row");
922  return row->get<T>(0);
923 }
924 
926 // Implementations for Iterator
928 
929 inline
930 Iterator::Iterator(const std::shared_ptr<Detail::StatementBase> &stmt)
931  : row_(stmt) {}
932 
933 inline const Row&
934 Iterator::dereference() const {
935  if (isEnd())
936  throw Exception("dereferencing the end iterator");
937  if (row_.sequence_ != row_.stmt_->sequence())
938  throw Exception("iterator has been invalidated");
939  return row_;
940 }
941 
942 inline bool
943 Iterator::equal(const Iterator &other) const {
944  return row_.stmt_ == other.row_.stmt_ && row_.sequence_ == other.row_.sequence_;
945 }
946 
947 inline void
948 Iterator::increment() {
949  if (isEnd())
950  throw Exception("incrementing the end iterator");
951  *this = row_.stmt_->next();
952 }
953 
955 // Implementations for Row
957 
958 inline
959 Row::Row(const std::shared_ptr<Detail::StatementBase> &stmt)
960  : stmt_(stmt), sequence_(stmt ? stmt->sequence() : 0) {}
961 
962 template<typename T>
963 inline Optional<T>
964 Row::get(size_t columnIdx) const {
965  ASSERT_not_null(stmt_);
966  if (sequence_ != stmt_->sequence())
967  throw Exception("row has been invalidated");
968  return stmt_->get<T>(columnIdx);
969 }
970 
971 inline size_t
972 Row::rowNumber() const {
973  ASSERT_not_null(stmt_);
974  if (sequence_ != stmt_->sequence())
975  throw Exception("row has been invalidated");
976  return stmt_->rowNumber();
977 }
978 
979 } // namespace
980 } // namespace
981 
982 #endif
983 #endif
STL namespace.
Holds a value or nothing.
Definition: Optional.h:49
bool increment(Word *vec1, const BitRange &range1)
Increment.
Name space for the entire library.