using System;
using System.Collections;
using System.Collections.Specialized;
using System.Configuration;
using System.Text;
using System.Data;
using System.Data.SqlClient;
using System.Reflection;
using log4net;
using Neo.Core;
using Neo.Core.Util;
using Neo.Framework;

namespace Neo.SqlClient
{
	/// <summary>
	/// Persists data from an ObjectContext into a SQL compliant database server
	/// and vice versa.
	/// </summary>
	/// <remarks>
	/// All commands run over a single <see cref="SqlConnection"/>. It is opened on demand
	/// and closed again after each command, unless a transaction is in progress.
	/// The generated SQL uses SQL Server constructs (<c>@@IDENTITY</c>, <c>LIKE ... ESCAPE</c>).
	/// </remarks>
	public class SqlDataStore : IDataStore
	{
		/// <summary>
		/// Escape character used in the LIKE patterns generated for string columns
		/// by <see cref="WriteOptimisticLockMatch"/>.
		/// </summary>
		public const string LIKE_ESCAPE_CHAR = "\\";

		//--------------------------------------------------------------------------------------
		//	Static fields and constructor
		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Static reference to a logging device
		/// </summary>
		protected static ILog logger = null;

		/// <summary>
		/// Maps CLR column types to the SqlDbType used for command parameters.
		/// Built once during type initialisation, so concurrent readers never see a partially filled table.
		/// </summary>
		private static readonly Hashtable dbTypeTable = CreateDbTypeTable();

		/// <summary>
		/// Initialises the class-wide logger, named after this type.
		/// </summary>
		static SqlDataStore()
		{
			if(logger == null)
				logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType.ToString());
		}

		//--------------------------------------------------------------------------------------
		//	Fields and constructor
		//--------------------------------------------------------------------------------------

		// Active explicit transaction, or null when commands run without one.
		private SqlTransaction transaction;
		// Rows written inside the current transaction; their changes are accepted only on commit.
		private ArrayList processedRows;
		SqlConnection connection = null;

		/// <summary>
		/// Default constructor. Reads the connection string from the 'connectionstring' key
		/// of the 'neo.sqlclient' configuration section.
		/// </summary>
		public SqlDataStore() : this(null)
		{
		}

		/// <summary>
		/// Creates a new connection using the supplied connection string.
		/// </summary>
		/// <param name="connectionString">
		/// Connection string to use. When null, the 'connectionstring' key of the
		/// 'neo.sqlclient' configuration section is used instead, if that section exists.
		/// </param>
		public SqlDataStore(string connectionString)
		{
			logger.Debug("New SqlDataStore created");
			if(connectionString == null)
			{
				NameValueCollection config = (NameValueCollection)ConfigurationSettings.GetConfig("neo.sqlclient");
				if(config != null)
					connectionString = config["connectionstring"];
			}
			connection = new SqlConnection(connectionString);
		}

		//--------------------------------------------------------------------------------------
		//	Explicit transactions
		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Marks the beginning of a transaction, opening the connection if necessary.
		/// </summary>
		/// <remarks>
		/// Rows saved while the transaction is active have their changes accepted only
		/// when <see cref="CommitTransaction"/> succeeds.
		/// </remarks>
		public virtual void BeginTransaction()
		{
			EnsureOpen();
			transaction = connection.BeginTransaction();
			processedRows = new ArrayList();
		}

		/// <summary>
		/// Commits the current transaction. It then accepts the changes of every row
		/// that was successfully written while the transaction was active.
		/// </summary>
		/// <exception cref="InvalidOperationException">No transaction is in progress.</exception>
		/// <remarks>
		/// If the commit itself throws, the transaction is left in place so that
		/// <see cref="RollbackTransaction"/> can still be attempted.
		/// </remarks>
		public virtual void CommitTransaction()
		{
			if(transaction == null)
				throw new InvalidOperationException("No transaction is in progress.");

			transaction.Commit();
			transaction = null;
			foreach(DataRow r in processedRows)
				r.AcceptChanges();
			processedRows = null;
		}

		/// <summary>
		/// Reverts the current transaction, if any. The rows written inside it keep their
		/// pending changes, because AcceptChanges is never called on them.
		/// </summary>
		/// <remarks>
		/// The transaction state is cleared even when the rollback itself throws.
		/// This allows the connection to be closed and a new transaction to be started afterwards.
		/// </remarks>
		public virtual void RollbackTransaction()
		{
			try
			{
				if(transaction != null)
					transaction.Rollback();
			}
			finally
			{
				transaction = null;
				processedRows = null;
			}
		}

		/// <summary>
		/// Makes sure the connection is in the open state, opening it if necessary.
		/// </summary>
		void EnsureOpen()
		{
			if(connection.State != ConnectionState.Open)
			{
				connection.Open();
			}
		}

		/// <summary>
		/// Closes the connection unless it is already closed or a transaction is still using it.
		/// </summary>
		void Close()
		{
			if(connection.State != ConnectionState.Closed && transaction == null)
			{
				connection.Close();
			}
		}

		//--------------------------------------------------------------------------------------
		//	IDataStore impl (fetching rows)
		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Fetches rows from a single database table that match the specification.
		/// </summary>
		/// <param name="fetchSpec">IFetchSpecification determining which rows to return</param>
		/// <returns>All matching rows from the table</returns>
		public DataTable FetchRows(IFetchSpecification fetchSpec)
		{
			DataTable table = new DataTable(fetchSpec.EntityMap.TableName);
			fetchSpec.EntityMap.UpdateSchema(table, SchemaUpdate.Basic);

			StringBuilder builder = new StringBuilder();
			builder.Append("SELECT ");
			WriteColumns(builder, table, table.Columns);
			builder.Append(" FROM ");
			builder.Append(table.TableName);

			Hashtable parameters = null;
			Qualifier qualifier = fetchSpec.Qualifier;
			if(qualifier != null)
			{
				builder.Append(" WHERE ");
				parameters = new Hashtable();
				WriteQualifier(fetchSpec.EntityMap, qualifier, builder, parameters);
			}
			FillTable(table, builder.ToString(), parameters);
			return table;
		}

		/// <summary>
		/// Fills the supplied table with the results of the query.
		/// </summary>
		/// <param name="table">table to be filled; its columns also describe the parameters</param>
		/// <param name="sqlCommand">SQL command to be executed</param>
		/// <param name="parameters">
		/// Parameter values, or null. A string key names a column of <paramref name="table"/>
		/// and produces the un-suffixed parameter for that column. When a column is qualified
		/// more than once, <see cref="WriteColumnQualifier"/> adds extra entries whose keys
		/// carry their own parameter-name suffix.
		/// </param>
		public void FillTable(DataTable table, string sqlCommand, IDictionary parameters)
		{
			using(SqlCommand cmd = CreateCommand(sqlCommand))
			using(SqlDataAdapter adapter = new SqlDataAdapter(cmd))
			{
				if(parameters != null)
				{
					foreach(DictionaryEntry entry in parameters)
					{
						QualifierParameterKey key = entry.Key as QualifierParameterKey;
						if(key != null)
							SetParameter(cmd, table.Columns[key.Column], key.Suffix, entry.Value);
						else
							SetParameter(cmd, table.Columns[(string)entry.Key], "", entry.Value);
					}
				}
				LogCommand(cmd);

				try
				{
					EnsureOpen();
					adapter.Fill(table);
				}
				finally
				{
					// Release the connection even when the query fails (no-op inside a transaction).
					Close();
				}

				if(logger.IsDebugEnabled)
					logger.Debug(String.Format("returned {0} rows", table.Rows.Count));
			}
		}

		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Appends the SQL condition for a qualifier to <paramref name="builder"/>.
		/// It dispatches on the qualifier's concrete type.
		/// </summary>
		/// <param name="emap">Map from Attribute to Column name and vice versa</param>
		/// <param name="q">qualifier to translate</param>
		/// <param name="builder">receives the SQL text</param>
		/// <param name="parameters">receives the parameter values referenced by the SQL text</param>
		/// <exception cref="ArgumentNullException"><paramref name="q"/> is null.</exception>
		/// <exception cref="ArgumentException">The qualifier type is not supported.</exception>
		protected virtual void WriteQualifier(IEntityMap emap, Qualifier q, StringBuilder builder, IDictionary parameters)
		{
			if(q == null)
				throw new ArgumentNullException("q");

			if(q is ColumnQualifier)
			{
				WriteColumnQualifier(emap, (ColumnQualifier)q, builder, parameters);
			}
			else if(q is PropertyQualifier)
			{
				PropertyQualifier pq = (PropertyQualifier)q;
				if(pq.Value is IEntityObject)
					WriteRelationPropertyQualifier(builder, emap, pq);
				else
					WriteColumnQualifier(emap, (ColumnQualifier)pq.GetColumnQualifier(emap), builder, parameters);
			}
			else if(q is ClauseQualifier)
			{
				WriteClauseQualifier(emap, (ClauseQualifier)q, builder, parameters);
			}
			else
			{
				throw new ArgumentException(String.Format("Invalid qualifier class {0}", q.GetType()));
			}
		}

		/// <summary>
		/// Writes an <c>exists (SELECT 1 FROM parent WHERE ...)</c> sub-query. It restricts rows of
		/// <paramref name="emap"/>'s table to those related to the entity object held in
		/// <paramref name="pq"/>.Value.
		/// </summary>
		/// <remarks>
		/// The relation is taken from the first ObjectRelationBase field of the entity object
		/// whose relation's child table is <paramref name="emap"/>'s table. The join assumes the
		/// related columns have the same names in the child and parent tables.
		/// </remarks>
		/// <exception cref="ApplicationException">No matching relation was found.</exception>
		protected virtual void WriteRelationPropertyQualifier(StringBuilder builder, IEntityMap emap, PropertyQualifier pq)
		{
			IEntityObject eo = pq.Value as IEntityObject;
			FieldInfo[] fields = eo.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);

			DataRelation theDataRelation = null;
			DataColumn[] pks = null;
			foreach(FieldInfo field in fields)
			{
				if(typeof(ObjectRelationBase).IsAssignableFrom(field.FieldType) == false)
					continue;

				ObjectRelationBase orb = (ObjectRelationBase)field.GetValue(eo);
				// A relation field may be unassigned; skip it rather than fail with a NullReferenceException.
				if(orb == null)
					continue;

				DataRelation relation = orb.Relation;
				if(relation != null && relation.ChildTable.TableName.Equals(emap.TableName))
				{
					theDataRelation = relation;
					pks = relation.ChildColumns;
					break;
				}
			}
			if(theDataRelation == null)
			{
				throw new ApplicationException("Tried to fetch against an EntityObject with no relations: \n" + eo.ToString() + "\n " + pq.ToString());
			}

			string parentTable = eo.Row.Table.TableName;
			builder.Append(" exists ( ");
			builder.Append(" SELECT 1 ");
			builder.Append(" FROM " + parentTable);
			builder.Append(" WHERE ");
			foreach(DataColumn pk in pks)
			{
				builder.Append(emap.TableName + "." + pk.ColumnName + " = " + parentTable + "." + pk.ColumnName);
				builder.Append(" and ");
			}
			// This method has no access to the parameter list, so the key value is inlined as a literal.
			// Single quotes are doubled so the value cannot terminate the literal early.
			string keyValue = Convert.ToString(eo.Row[pks[0].ColumnName]).Replace("'", "''");
			builder.Append(parentTable + "." + pks[0].ColumnName + " = '" + keyValue + "'");
			builder.Append(")");
		}

		/// <summary>
		/// Writes a single comparison between a column and a parameter.
		/// </summary>
		/// <remarks>
		/// A null or DBNull value compared with Equal or NotEqual is written as
		/// <c>IS NULL</c> or <c>IS NOT NULL</c>, because <c>= NULL</c> never matches.
		/// If the column already has a parameter (e.g. in a range <c>x &gt; a AND x &lt; b</c>),
		/// a new suffixed parameter is created so the earlier value is not overwritten.
		/// </remarks>
		/// <exception cref="ArgumentException">The operator is not supported.</exception>
		protected virtual void WriteColumnQualifier(IEntityMap emap, ColumnQualifier q, StringBuilder builder, IDictionary parameters)
		{
			bool valueIsNull = (q.Value == null) || (q.Value is DBNull);
			if(valueIsNull && (q.Operator == QualifierOperator.Equal || q.Operator == QualifierOperator.NotEqual))
			{
				builder.Append(q.Column);
				builder.Append((q.Operator == QualifierOperator.Equal) ? " IS NULL" : " IS NOT NULL");
				return;
			}

			string comparison;
			switch(q.Operator)
			{
				case QualifierOperator.Equal:
					comparison = "=";
					break;
				case QualifierOperator.NotEqual:
					comparison = "<>";
					break;
				case QualifierOperator.GreaterThan:
					comparison = ">";
					break;
				case QualifierOperator.LessThan:
					comparison = "<";
					break;
				default:
					throw new ArgumentException("Invalid operator in qualifier.");
			}

			// The first use of a column keeps the plain column-name key (and the un-suffixed parameter name).
			// Later uses get "_1", "_2", ... so each occurrence binds its own value.
			string suffix = "";
			object key = q.Column;
			int n = 0;
			while(parameters.Contains(key) || (suffix.Length > 0 && parameters.Contains(q.Column + suffix)))
			{
				n++;
				suffix = "_" + n;
				key = new QualifierParameterKey(q.Column, suffix);
			}

			builder.Append(q.Column);
			builder.Append(comparison);
			builder.Append(ConvertToParameterName(q.Column, suffix));
			parameters[key] = q.Value;
		}

		/// <summary>
		/// Writes the child qualifiers in parentheses, joined by the clause's conjunctor.
		/// </summary>
		/// <param name="emap">Map from Attribute to Column name and vice versa</param>
		/// <param name="q">clause whose children are written</param>
		/// <param name="builder">receives the SQL text</param>
		/// <param name="parameters">receives the parameter values referenced by the SQL text</param>
		/// <remarks>
		/// A clause without children becomes <c>(1=1)</c> for AND and <c>(1=0)</c> for OR.
		/// An empty <c>()</c> would be invalid SQL.
		/// </remarks>
		/// <exception cref="ArgumentException">The conjunctor is not supported.</exception>
		protected virtual void WriteClauseQualifier(IEntityMap emap, ClauseQualifier q, StringBuilder builder, IDictionary parameters)
		{
			string conjunctor;
			string emptyClause;
			switch(q.Conjunctor)
			{
				case QualifierConjunctor.And:
					conjunctor = " AND ";
					emptyClause = "1=1";
					break;
				case QualifierConjunctor.Or:
					conjunctor = " OR ";
					emptyClause = "1=0";
					break;
				default:
					throw new ArgumentException("Invalid conjunctor in qualifier.");
			}

			bool isFirstChild = true;
			builder.Append("(");
			foreach(Qualifier child in q.Qualifiers)
			{
				if(isFirstChild == false)
					builder.Append(conjunctor);
				isFirstChild = false;
				WriteQualifier(emap, child, builder, parameters);
			}
			if(isFirstChild)
				builder.Append(emptyClause);
			builder.Append(")");
		}

		//--------------------------------------------------------------------------------------
		//	IDataStore impl (saving changes)
		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Saves the changes made in the context to the database.
		/// </summary>
		/// <param name="context">The ObjectContext containing the changed data</param>
		/// <returns>A list of PkChangeTable objects, one per table whose primary keys were reassigned</returns>
		/// <remarks>
		/// A new transaction is started, and each table of the context's DataSet is saved in turn.
		/// If any row reports an error, or any exception occurs, the whole transaction is rolled back.
		/// A DataStoreSaveException is then thrown.
		/// </remarks>
		/// <exception cref="DataStoreSaveException">The save failed and was rolled back.</exception>
		public ICollection SaveChangesInObjectContext(ObjectContext context)
		{
			ArrayList pkChangeTableList = new ArrayList();
			try
			{
				EnsureOpen();
				BeginTransaction();
				// WARNING: This is naive and will not work if foreign key constraints are
				// violated because SQL Server can't delay constraints. To fix we must ensure
				// that children are inserted before their parents and deleted after them.
				// Therefore, we can't do the saves simply table-by-table.
				StringBuilder errorString = new StringBuilder();
				foreach(DataTable t in context.DataSet.Tables)
				{
					PkChangeTable pkChangeTable = Save(t);
					if(pkChangeTable.Count > 0)
						pkChangeTableList.Add(pkChangeTable);
					if(t.HasErrors)
					{
						errorString.Append(String.Format("\n{0}:\n", t.TableName));
						foreach(DataRow row in t.GetErrors())
							errorString.Append(String.Format("{0}\n", row.RowError));
					}
				}
				if(errorString.Length > 0)
					throw new DataStoreSaveException("Errors while saving: " + errorString.ToString());
				CommitTransaction();
			}
			catch(Exception e)
			{
				// A failing rollback must not hide the error that caused it.
				try
				{
					RollbackTransaction();
				}
				catch(Exception rollbackError)
				{
					logger.Error("Rollback after a failed save also failed.", rollbackError);
				}
				if(e is DataStoreSaveException)
					throw;	// rethrow without resetting the stack trace
				throw new DataStoreSaveException(e);
			}
			finally
			{
				Close();
			}
			return pkChangeTableList;
		}

		/// <summary>
		/// Saves changes to any modified rows in the supplied table.
		/// Each added, deleted or modified row is inserted, deleted or updated respectively.
		/// </summary>
		/// <param name="table">The table containing changes</param>
		/// <returns>a PkChangeTable recording primary keys reassigned by the database</returns>
		/// <seealso cref="InsertRow"/>
		/// <seealso cref="UpdateRow"/>
		/// <seealso cref="DeleteRow"/>
		public virtual PkChangeTable Save(DataTable table)
		{
			PkChangeTable pkChangeTable = new PkChangeTable(table.TableName);
			DataRow[] modrows = table.Select("", "", DataViewRowState.Added | DataViewRowState.Deleted | DataViewRowState.ModifiedCurrent);
			foreach(DataRow row in modrows)
			{
				switch(row.RowState)
				{
					case DataRowState.Added:
						InsertRow(row, pkChangeTable);
						break;
					case DataRowState.Deleted:
						DeleteRow(row);
						break;
					case DataRowState.Modified:
						UpdateRow(row);
						break;
				}
			}
			return pkChangeTable;
		}

		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Inserts a new row into the database.
		/// </summary>
		/// <param name="row">row containing data to be inserted</param>
		/// <param name="pkChangeTable">PkChangeTable to record the newly assigned key</param>
		/// <remarks>
		/// A single auto-increment primary key column is treated as an identity column. It is
		/// left out of the INSERT, and the value the server generated is read back into the row.
		/// Tables without a primary key are inserted as-is.
		/// </remarks>
		protected virtual void InsertRow(DataRow row, PkChangeTable pkChangeTable)
		{
			DataTable table = row.Table;
			DataColumn[] primaryKey = table.PrimaryKey;
			DataColumn pkColumn = (primaryKey.Length > 0) ? primaryKey[0] : null;
			bool usesIdentityColumn = (primaryKey.Length == 1) && pkColumn.AutoIncrement;

			ArrayList columnList = new ArrayList();
			foreach(DataColumn column in table.Columns)
			{
				if((usesIdentityColumn == false) || (column != pkColumn))
					columnList.Add(column);
			}

			StringBuilder builder = new StringBuilder();
			builder.Append("INSERT INTO ");
			builder.Append(table.TableName);
			if(columnList.Count == 0)
			{
				// Only the identity column exists; "() VALUES ()" would be invalid SQL.
				builder.Append(" DEFAULT VALUES");
			}
			else
			{
				builder.Append(" (");
				WriteColumns(builder, table, columnList);
				builder.Append(") VALUES (");
				WriteParameters(builder, table, "", columnList);
				builder.Append(")");
			}

			ArrayList parameters = new ArrayList();
			foreach(DataColumn column in table.Columns)
				parameters.Add(GetParameter(column, "", row[column]));

			int rowsAffected = ExecuteNonQuery(builder.ToString(), parameters);

			if(usesIdentityColumn && (rowsAffected == 1))
			{
				// NOTE(review): @@IDENTITY also reports identities generated by triggers on this table.
				// SCOPE_IDENTITY() would not help here, because the INSERT runs as a separate
				// (parameterised) batch.
				object finalPkValue = ExecuteScalar("SELECT @@IDENTITY AS NEW_ID", null);
				pkChangeTable.AddPkChange(row[pkColumn], finalPkValue);
				row[pkColumn] = finalPkValue;
			}
			PostProcessRow(row, (rowsAffected == 1), "Failed to insert row into database.");
		}

		/// <summary>
		/// Deletes a row from the database. The row must still match its original values in every column.
		/// </summary>
		/// <param name="row">row containing data to be deleted</param>
		protected virtual void DeleteRow(DataRow row)
		{
			StringBuilder builder = new StringBuilder();
			builder.Append("DELETE FROM ");
			builder.Append(row.Table.TableName);
			builder.Append(" WHERE");
			WriteOptimisticLockMatch(builder, row.Table, "");

			ArrayList parameters = new ArrayList();
			foreach(DataColumn column in row.Table.Columns)
			{
				// String columns are matched with LIKE ... ESCAPE, so their values must be escaped.
				parameters.Add(GetParameter(column, "", row[column, DataRowVersion.Original], true));
			}

			int rowsAffected = ExecuteNonQuery(builder.ToString(), parameters);
			PostProcessRow(row, (rowsAffected == 1), "Failed to delete row in database.");
		}

		/// <summary>
		/// Writes a row's current values to the database. It only succeeds if the stored row still
		/// matches the row's original values in every column.
		/// </summary>
		/// <param name="row">row to be updated</param>
		protected virtual void UpdateRow(DataRow row)
		{
			StringBuilder builder = new StringBuilder();
			builder.Append("UPDATE ");
			builder.Append(row.Table.TableName);
			builder.Append(" SET");
			WriteAllColumnsAndParameters(builder, row.Table, "", ", ");
			builder.Append(" WHERE");
			WriteOptimisticLockMatch(builder, row.Table, "_ORIG");

			ArrayList parameters = new ArrayList();
			foreach(DataColumn column in row.Table.Columns)
			{
				parameters.Add(GetParameter(column, "", row[column]));
				parameters.Add(GetParameter(column, "_ORIG", row[column, DataRowVersion.Original], true));
			}

			int rowsAffected = ExecuteNonQuery(builder.ToString(), parameters);
			PostProcessRow(row, (rowsAffected == 1), "Failed to update row in database.");
		}

		/// <summary>
		/// Escapes a value so that a <c>LIKE ... ESCAPE LIKE_ESCAPE_CHAR</c> pattern matches it literally.
		/// </summary>
		private string EscapeLikeCharacters(string fieldValue)
		{
			// The escape character must be escaped first. Otherwise a literal backslash in the
			// value would be read as an escape, and the pattern would no longer match the stored value.
			string escapedVersion = EscapeLikeCharacter(fieldValue, LIKE_ESCAPE_CHAR, LIKE_ESCAPE_CHAR);
			escapedVersion = EscapeLikeCharacter(escapedVersion, "%", LIKE_ESCAPE_CHAR);
			escapedVersion = EscapeLikeCharacter(escapedVersion, "_", LIKE_ESCAPE_CHAR);
			escapedVersion = EscapeLikeCharacter(escapedVersion, "[", LIKE_ESCAPE_CHAR);
			escapedVersion = EscapeLikeCharacter(escapedVersion, "]", LIKE_ESCAPE_CHAR);
			escapedVersion = EscapeLikeCharacter(escapedVersion, "^", LIKE_ESCAPE_CHAR);
			return escapedVersion;
		}

		/// <summary>
		/// Prefixes every occurrence of <paramref name="searchString"/> with <paramref name="escapeChar"/>.
		/// </summary>
		private string EscapeLikeCharacter(string fieldValue, string searchString, string escapeChar)
		{
			return fieldValue.Replace(searchString, escapeChar + searchString);
		}

		//--------------------------------------------------------------------------------------
		//	Other public methods
		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Deletes all rows from a table.
		/// </summary>
		/// <param name="tableName">name of table to be cleared; inserted verbatim into the SQL</param>
		public virtual void ClearTable(string tableName)
		{
			ExecuteNonQuery("DELETE FROM " + tableName, null);
		}

		//--------------------------------------------------------------------------------------
		//	Internal helper
		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Creates a SqlCommand on the current connection. The command is enlisted in the
		/// current transaction, if there is one.
		/// </summary>
		/// <param name="text">SQL command to be executed</param>
		/// <returns>A SqlCommand object using current connection properties</returns>
		protected virtual SqlCommand CreateCommand(string text)
		{
			// transaction is null outside an explicit transaction, which SqlCommand accepts.
			return CreateCommand(text, connection, transaction);
		}

		/// <summary>
		/// Creates a SqlCommand object using supplied connection properties.
		/// </summary>
		/// <param name="text">SQL command to be executed</param>
		/// <param name="conn">Connection to be used</param>
		/// <param name="txn">transaction to contain this command, or null</param>
		/// <returns>A SqlCommand object using supplied connection properties</returns>
		protected virtual SqlCommand CreateCommand(string text, SqlConnection conn, SqlTransaction txn)
		{
			return new SqlCommand(text, conn, txn);
		}

		/// <summary>
		/// Writes the command text and each parameter name/value to the debug log, if enabled.
		/// </summary>
		private static void LogCommand(SqlCommand cmd)
		{
			if(logger.IsDebugEnabled == false)
				return;
			logger.Debug(cmd.CommandText);
			foreach(SqlParameter p in cmd.Parameters)
				logger.Debug(" " + p.ParameterName + " = " + p.Value);
		}

		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Executes a SQL command on the database. Expects no return data.
		/// </summary>
		/// <param name="sqlCommand">SQL command to be executed</param>
		/// <param name="parameters">List of SqlParameter objects applied to the command, or null</param>
		/// <returns>number of rows affected by this command</returns>
		protected virtual int ExecuteNonQuery(string sqlCommand, IList parameters)
		{
			using(SqlCommand cmd = CreateCommand(sqlCommand))
			{
				if(parameters != null)
				{
					foreach(SqlParameter parameter in parameters)
						cmd.Parameters.Add(parameter);
				}
				LogCommand(cmd);

				int rowsAffected;
				try
				{
					EnsureOpen();
					rowsAffected = cmd.ExecuteNonQuery();
				}
				finally
				{
					// Release the connection even when the command fails (no-op inside a transaction).
					Close();
				}

				if(logger.IsDebugEnabled)
					logger.Debug(String.Format("{0} rows affected", rowsAffected));
				return rowsAffected;
			}
		}

		/// <summary>
		/// Executes the query, and returns the first column of the first row in the result set returned by the query.
		/// Extra columns or rows are ignored.
		/// </summary>
		/// <param name="sqlCommand">SQL Command to be executed</param>
		/// <param name="parameters">Map from DataColumn to parameter value, or null</param>
		/// <returns>The first column of the first row in the result set, or a null reference if the result set is empty.</returns>
		protected virtual object ExecuteScalar(string sqlCommand, IDictionary parameters)
		{
			using(SqlCommand cmd = CreateCommand(sqlCommand))
			{
				if(parameters != null)
				{
					foreach(DictionaryEntry entry in parameters)
						SetParameter(cmd, (DataColumn)entry.Key, "", entry.Value);
				}
				LogCommand(cmd);

				object result;
				try
				{
					EnsureOpen();
					result = cmd.ExecuteScalar();
				}
				finally
				{
					Close();
				}

				if(logger.IsDebugEnabled)
					logger.Debug(String.Format("result = {0}", result));
				return result;
			}
		}

		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Handles the result of writing a row.
		/// </summary>
		/// <param name="row">row that has just been processed</param>
		/// <param name="succeeded">true if the operation was successful</param>
		/// <param name="errorMessage">message stored in the row's RowError on failure</param>
		/// <remarks>
		/// On success the row's changes are accepted immediately. Inside a transaction,
		/// acceptance is deferred until <see cref="CommitTransaction"/>.
		/// </remarks>
		protected virtual void PostProcessRow(DataRow row, bool succeeded, string errorMessage)
		{
			if(succeeded)
			{
				if(transaction != null)
					processedRows.Add(row);
				else
					row.AcceptChanges();
			}
			else
			{
				row.RowError = errorMessage;
			}
		}

		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Appends a comma-separated list of column names to the supplied StringBuilder object.
		/// </summary>
		/// <param name="builder">StringBuilder to which column names are appended</param>
		/// <param name="table">Not used</param>
		/// <param name="columns">A collection of DataColumn objects</param>
		protected void WriteColumns(StringBuilder builder, DataTable table, ICollection columns)
		{
			bool first = true;
			foreach(DataColumn c in columns)
			{
				if(first == false)
					builder.Append(", ");
				builder.Append(c.ColumnName);
				first = false;
			}
		}

		/// <summary>
		/// Appends a comma-separated list of parameter names to the supplied StringBuilder object.
		/// </summary>
		/// <param name="builder">StringBuilder to which parameter names are appended</param>
		/// <param name="table">Not used</param>
		/// <param name="suffix">appended to each column name to form the parameter name</param>
		/// <param name="columns">A collection of DataColumn objects</param>
		/// <seealso cref="ConvertToParameterName"/>
		protected void WriteParameters(StringBuilder builder, DataTable table, string suffix, ICollection columns)
		{
			bool first = true;
			foreach(DataColumn c in columns)
			{
				if(first == false)
					builder.Append(", ");
				builder.Append(ConvertToParameterName(c.ColumnName, suffix));
				first = false;
			}
		}

		/// <summary>
		/// Appends <c>ColumnName = ParameterName</c> pairs, separated by <paramref name="separator"/>,
		/// for every column of the table.
		/// </summary>
		/// <param name="builder">StringBuilder to which text is appended</param>
		/// <param name="table">DataTable containing column names to be appended</param>
		/// <param name="suffix">used to generate parameter names</param>
		/// <param name="separator">string used to separate column name-parameter pairs</param>
		/// <remarks>
		/// Used by UpdateRow(). Auto-increment columns are skipped, since they cannot be updated.
		/// </remarks>
		protected void WriteAllColumnsAndParameters(StringBuilder builder, DataTable table, string suffix, string separator)
		{
			bool isFirst = true;
			for(int i = 0; i < table.Columns.Count; i++)
			{
				if(table.Columns[i].AutoIncrement == false)
				{
					if(!isFirst)
						builder.Append(separator);
					else
						isFirst = false;

					builder.Append(" ");
					builder.Append(table.Columns[i].ColumnName);
					builder.Append(" = ");
					builder.Append(ConvertToParameterName(table.Columns[i].ColumnName, suffix));
				}
			}
		}

		/// <summary>
		/// Appends a condition for optimistic concurrency. It matches a row only if every column
		/// equals its parameter, or both the column and the parameter are NULL.
		/// </summary>
		/// <param name="builder">StringBuilder to which text is appended</param>
		/// <param name="table">DataTable containing column names</param>
		/// <param name="suffix">used to generate parameter names</param>
		/// <remarks>
		/// String columns are compared with <c>LIKE ... ESCAPE LIKE_ESCAPE_CHAR</c> rather than <c>=</c>.
		/// This is presumably so that text columns, which SQL Server cannot compare with <c>=</c>,
		/// can still be matched (confirm). Their parameter values must therefore be escaped
		/// (see GetParameter with escapeValue).
		/// </remarks>
		protected void WriteOptimisticLockMatch(StringBuilder builder, DataTable table, string suffix)
		{
			// This is not really correct, we should exclude BLOBs. Then again, we don't do BLOBs yet.
			for(int i = 0; i < table.Columns.Count; i++)
			{
				DataColumn column = table.Columns[i];
				string paramName = ConvertToParameterName(column.ColumnName, suffix);

				if(i > 0)
					builder.Append(" AND");

				builder.Append(" ((");
				builder.Append(column.ColumnName);
				if(column.DataType != typeof(string))
				{
					builder.Append(" = ");
					builder.Append(paramName);
				}
				else
				{
					builder.Append(" LIKE ");
					builder.Append(paramName);
					builder.Append(" ESCAPE '" + LIKE_ESCAPE_CHAR + "'");
				}
				builder.Append(" ) OR ((");
				builder.Append(column.ColumnName);
				builder.Append(" IS NULL) AND (");
				builder.Append(paramName);
				builder.Append(" IS NULL)))");
			}
		}

		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Returns the SqlDbType for a column's data type.
		/// </summary>
		/// <exception cref="ArgumentException">The column's type has no SqlDbType mapping.</exception>
		private SqlDbType GetSupportedDbType(DataColumn pcolumn)
		{
			SqlDbType dbtype = ConvertToDbType(pcolumn.DataType);
			if(dbtype == SqlDbType.Variant)
				throw new ArgumentException(String.Format("Cannot handle data type {0} of column {1} in table {2}.", pcolumn.DataType.ToString(), pcolumn.ColumnName, pcolumn.Table.TableName));
			return dbtype;
		}

		/// <summary>
		/// Creates a new SqlParameter based upon the supplied DataColumn and adds it to the SqlCommand.
		/// </summary>
		/// <param name="command">SqlCommand object to which the parameter is to be added</param>
		/// <param name="pcolumn">DataColumn containing column information</param>
		/// <param name="suffix">used to generate parameter names</param>
		/// <param name="pvalue">the value to assign to the parameter</param>
		/// <returns>newly created SqlParameter</returns>
		/// <exception cref="ArgumentException">The column's type has no SqlDbType mapping.</exception>
		protected virtual SqlParameter SetParameter(SqlCommand command, DataColumn pcolumn, string suffix, object pvalue)
		{
			SqlDbType dbtype = GetSupportedDbType(pcolumn);
			SqlParameter param = command.Parameters.Add(ConvertToParameterName(pcolumn.ColumnName, suffix), dbtype, Math.Max(0, pcolumn.MaxLength), pcolumn.ColumnName);
			param.IsNullable = pcolumn.AllowDBNull;
			param.Value = pvalue;
			return param;
		}

		/// <summary>
		/// Creates a new SqlParameter based upon the supplied DataColumn.
		/// </summary>
		/// <param name="pcolumn">DataColumn containing column information</param>
		/// <param name="suffix">used to generate parameter names</param>
		/// <param name="pvalue">the value to assign to the parameter</param>
		/// <returns>newly created SqlParameter</returns>
		/// <exception cref="ArgumentException">The column's type has no SqlDbType mapping.</exception>
		protected virtual SqlParameter GetParameter(DataColumn pcolumn, string suffix, object pvalue)
		{
			SqlDbType dbtype = GetSupportedDbType(pcolumn);
			SqlParameter param = new SqlParameter(ConvertToParameterName(pcolumn.ColumnName, suffix), dbtype, Math.Max(0, pcolumn.MaxLength), pcolumn.ColumnName);
			param.IsNullable = pcolumn.AllowDBNull;
			param.Value = pvalue;
			return param;
		}

		/// <summary>
		/// Creates a new SqlParameter based upon the supplied DataColumn. Non-null string values
		/// can optionally be escaped for use in a LIKE pattern.
		/// </summary>
		/// <param name="pcolumn">DataColumn containing column information</param>
		/// <param name="suffix">used to generate parameter names</param>
		/// <param name="pvalue">the value to assign to the parameter</param>
		/// <param name="escapeValue">true to escape LIKE wildcards in string values</param>
		/// <returns>newly created SqlParameter</returns>
		protected virtual SqlParameter GetParameter(DataColumn pcolumn, string suffix, object pvalue, bool escapeValue)
		{
			bool escapeString = escapeValue && (pvalue != null) && (pvalue != DBNull.Value) && (pcolumn.DataType == typeof(string));
			if(escapeString)
				return GetParameter(pcolumn, suffix, EscapeLikeCharacters(pvalue.ToString()));
			return GetParameter(pcolumn, suffix, pvalue);
		}

		//--------------------------------------------------------------------------------------

		/// <summary>
		/// Converts a column name to a parameter name.
		/// </summary>
		/// <param name="column">column name</param>
		/// <param name="suffix">suffix appended to column name</param>
		/// <returns>parameter name of the form <c>@column</c> followed by the suffix</returns>
		protected virtual string ConvertToParameterName(string column, string suffix)
		{
			return "@" + column + suffix;
		}

		/// <summary>
		/// Builds the CLR type to SqlDbType mapping used by <see cref="ConvertToDbType"/>.
		/// </summary>
		private static Hashtable CreateDbTypeTable()
		{
			Hashtable table = new Hashtable();
			table.Add(typeof(System.Boolean), SqlDbType.Bit);
			table.Add(typeof(System.Int16), SqlDbType.SmallInt);
			table.Add(typeof(System.Int32), SqlDbType.Int);
			table.Add(typeof(System.Int64), SqlDbType.BigInt);
			table.Add(typeof(System.Double), SqlDbType.Float);
			table.Add(typeof(System.Decimal), SqlDbType.Decimal);
			// NOTE(review): VarChar parameters cannot carry characters outside the server code page.
			// Consider NVarChar for Unicode columns.
			table.Add(typeof(System.String), SqlDbType.VarChar);
			table.Add(typeof(System.DateTime), SqlDbType.DateTime);
			table.Add(typeof(System.Byte[]), SqlDbType.VarBinary);
			table.Add(typeof(System.Guid), SqlDbType.UniqueIdentifier);
			return table;
		}

		/// <summary>
		/// Converts a System.* type to an equivalent SqlDbType.* type.
		/// </summary>
		/// <param name="t">type to convert</param>
		/// <returns>equivalent SqlDbType.* type, or SqlDbType.Variant if the type is not supported</returns>
		protected virtual SqlDbType ConvertToDbType(Type t)
		{
			if(t == null)
				return SqlDbType.Variant;
			object mapped = dbTypeTable[t];
			return (mapped != null) ? (SqlDbType)mapped : SqlDbType.Variant;
		}

		/// <summary>
		/// Parameter-dictionary key for the second and later occurrence of a column in a qualifier.
		/// It carries the suffix that distinguishes that occurrence's parameter name.
		/// </summary>
		private sealed class QualifierParameterKey
		{
			public readonly string Column;
			public readonly string Suffix;

			public QualifierParameterKey(string column, string suffix)
			{
				Column = column;
				Suffix = suffix;
			}

			public override bool Equals(object obj)
			{
				QualifierParameterKey other = obj as QualifierParameterKey;
				return (other != null) && (other.Column == Column) && (other.Suffix == Suffix);
			}

			public override int GetHashCode()
			{
				int columnHash = (Column == null) ? 0 : Column.GetHashCode();
				int suffixHash = (Suffix == null) ? 0 : Suffix.GetHashCode();
				return columnHash ^ suffixHash;
			}
		}
	}
}