using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Globalization;
using System.Reflection;
using System.Runtime.InteropServices.ComTypes;
using Bowin.Common.Data;
using Microsoft.Data.SqlClient;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;

namespace OrderSystem.Services.Common
{
    /// <summary>
    /// Raw-SQL query helpers for <see cref="DatabaseFacade"/> that materialise results into a
    /// <see cref="DataTable"/> or into instances of a caller-supplied class.
    /// </summary>
    public static class DbContextExtensions
    {
        /// <summary>
        /// Adds <paramref name="parameters"/> to <paramref name="command"/>, prefixing each parameter
        /// name with "@" when it does not already start with one.
        /// </summary>
        /// <exception cref="ArgumentException">An element is null or is not a <see cref="DbParameter"/>.</exception>
        private static void CombineParams(DbCommand command, object[] parameters)
        {
            if (parameters == null)
            {
                return;
            }

            foreach (object item in parameters)
            {
                // Only provider parameter objects can go into a command's parameter collection;
                // reject raw values and nulls with a clear message instead of an opaque cast failure.
                if (!(item is DbParameter parameter))
                {
                    string actual = item == null ? "null" : item.GetType().FullName;
                    throw new ArgumentException(
                        $"Each parameter must be a DbParameter (for example a SqlParameter), but got '{actual}'.",
                        nameof(parameters));
                }

                string name = parameter.ParameterName ?? string.Empty;
                if (!name.StartsWith("@", StringComparison.Ordinal))
                {
                    parameter.ParameterName = "@" + name;
                }

                command.Parameters.Add(parameter);
            }
        }

        /// <summary>
        /// Loads <paramref name="reader"/> into <paramref name="table"/>. A <see cref="DataException"/>
        /// raised during the load (for example a <see cref="ConstraintException"/>) is swallowed: the rows
        /// loaded so far are kept and their row errors are cleared.
        /// </summary>
        private static void LoadTolerant(DataTable table, DbDataReader reader)
        {
            try
            {
                table.Load(reader);
            }
            catch (DataException)
            {
                // NOTE(review): the error details are discarded and a possibly partial table is
                // returned; consider logging them or surfacing them to the caller.
                foreach (DataRow row in table.GetErrors())
                {
                    row.ClearErrors();
                }
            }
        }

        /// <summary>
        /// Executes <paramref name="sql"/> on the context's connection and returns the result set as a
        /// <see cref="DataTable"/>. Only the SQL Server provider is supported.
        /// </summary>
        /// <param name="facade">The context's database facade.</param>
        /// <param name="sql">
        /// SQL text, sent verbatim. Pass user-supplied values through <paramref name="parameters"/>;
        /// never concatenate them into this string.
        /// </param>
        /// <param name="parameters">
        /// <see cref="DbParameter"/> instances (e.g. SqlParameter). Names without a leading "@" get one
        /// added. They are detached from the command afterwards, so the same objects may be reused.
        /// </param>
        /// <returns>
        /// The loaded table. If loading raises a <see cref="DataException"/>, the rows loaded so far are
        /// returned with their row errors cleared.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="facade"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="sql"/> is null or blank, or an element of <paramref name="parameters"/> is not a <see cref="DbParameter"/>.
        /// </exception>
        /// <exception cref="NotSupportedException">The context is not using the SQL Server provider.</exception>
        public static DataTable SqlQuery(this DatabaseFacade facade, string sql, params object[] parameters)
        {
            if (facade == null)
            {
                throw new ArgumentNullException(nameof(facade));
            }

            if (string.IsNullOrWhiteSpace(sql))
            {
                throw new ArgumentException("The SQL text must not be null or empty.", nameof(sql));
            }

            // Checked before touching the connection so an unsupported provider fails fast instead of
            // opening the connection and executing a command whose text was never set.
            if (!facade.IsSqlServer())
            {
                throw new NotSupportedException("SqlQuery only supports the SQL Server provider.");
            }

            var table = new DataTable();

            // Let EF Core manage the connection state: a connection that was already open before this
            // call (e.g. because a transaction is in progress) is not closed by CloseConnection.
            facade.OpenConnection();
            try
            {
                using (DbCommand command = facade.GetDbConnection().CreateCommand())
                {
                    command.CommandText = sql;

                    // SqlClient refuses to execute a command on a connection with a pending local
                    // transaction unless the command is enlisted in it. This unwraps the ADO.NET
                    // transaction the same way EF Core's GetDbTransaction() does.
                    if (facade.CurrentTransaction is IInfrastructure<DbTransaction> transaction)
                    {
                        command.Transaction = transaction.Instance;
                    }

                    try
                    {
                        CombineParams(command, parameters);
                        using (DbDataReader reader = command.ExecuteReader())
                        {
                            LoadTolerant(table, reader);
                        }
                    }
                    finally
                    {
                        // A SqlParameter can belong to only one parameter collection; detach the
                        // caller's parameter objects so they can be passed to a later query.
                        command.Parameters.Clear();
                    }
                }
            }
            finally
            {
                facade.CloseConnection();
            }

            return table;
        }

        /// <summary>
        /// Executes <paramref name="sql"/> like <see cref="SqlQuery(DatabaseFacade, string, object[])"/> and
        /// maps each resulting row to a new <typeparamref name="T"/> via <see cref="ToEnumerable{T}(DataTable)"/>.
        /// </summary>
        /// <typeparam name="T">Target type; needs a public parameterless constructor.</typeparam>
        /// <returns>One <typeparamref name="T"/> per row, in row order.</returns>
        public static IEnumerable<T> SqlQuery<T>(this DatabaseFacade facade, string sql, params object[] parameters)
            where T : class, new()
        {
            using (DataTable table = SqlQuery(facade, sql, parameters))
            {
                // ToEnumerable materialises an array, so the table can be released afterwards.
                return table.ToEnumerable<T>();
            }
        }

        /// <summary>
        /// Maps each row of <paramref name="dt"/> to a new <typeparamref name="T"/> by copying column values
        /// into writable public instance properties whose names match a column.
        /// </summary>
        /// <remarks>
        /// Columns are matched with <see cref="DataColumnCollection.IndexOf(string)"/>. <see cref="DBNull"/>
        /// values leave the property at its default. Values whose runtime type does not match the property
        /// type are converted (nullable underlying type, enum, otherwise
        /// <see cref="Convert.ChangeType(object, Type, IFormatProvider)"/> with the invariant culture).
        /// Read-only properties and indexers are skipped.
        /// </remarks>
        /// <typeparam name="T">Target type; needs a public parameterless constructor.</typeparam>
        /// <param name="dt">The table to map.</param>
        /// <returns>An array with one element per row, in row order.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="dt"/> is null.</exception>
        public static IEnumerable<T> ToEnumerable<T>(this DataTable dt) where T : class, new()
        {
            if (dt == null)
            {
                throw new ArgumentNullException(nameof(dt));
            }

            // Resolve the property-to-column mapping once rather than once per row.
            var mappings = new List<(PropertyInfo Property, DataColumn Column)>();
            foreach (PropertyInfo candidate in typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance))
            {
                // Setting a get-only property or an indexer through reflection throws, so skip them.
                if (!candidate.CanWrite || candidate.GetIndexParameters().Length > 0)
                {
                    continue;
                }

                int index = dt.Columns.IndexOf(candidate.Name);
                if (index != -1)
                {
                    mappings.Add((candidate, dt.Columns[index]));
                }
            }

            var items = new T[dt.Rows.Count];
            for (int i = 0; i < items.Length; i++)
            {
                DataRow row = dt.Rows[i];
                var item = new T();
                foreach (var (property, column) in mappings)
                {
                    object value = row[column];
                    if (!(value is DBNull))
                    {
                        property.SetValue(item, ConvertValue(value, property.PropertyType), null);
                    }
                }

                items[i] = item;
            }

            return items;
        }

        /// <summary>
        /// Returns <paramref name="value"/> unchanged when it is already assignable to
        /// <paramref name="targetType"/>; otherwise converts it to that type.
        /// </summary>
        private static object ConvertValue(object value, Type targetType)
        {
            Type effectiveType = Nullable.GetUnderlyingType(targetType) ?? targetType;
            if (effectiveType.IsInstanceOfType(value))
            {
                // Reflection wraps a boxed T into Nullable<T> by itself, so no conversion is needed.
                return value;
            }

            if (effectiveType.IsEnum)
            {
                return value is string name
                    ? Enum.Parse(effectiveType, name)
                    : Enum.ToObject(effectiveType, value);
            }

            return Convert.ChangeType(value, effectiveType, CultureInfo.InvariantCulture);
        }
    }
}