using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using MongoDB.Bson;
using MongoDB.Driver;
using Streetwriters.Data.Attributes;
using Streetwriters.Data.Interfaces;

namespace Streetwriters.Data.Repositories
{
    /// <summary>
    /// Generic MongoDB repository for <typeparamref name="TEntity"/>.
    /// </summary>
    /// <remarks>
    /// Methods without the <c>Async</c> suffix (<see cref="Insert"/>, <see cref="Upsert"/>,
    /// <see cref="Update"/>, <see cref="DeleteById"/>, <see cref="Delete"/>, <see cref="DeleteMany"/>)
    /// do not execute anything themselves. They register a command with <c>IDbContext.AddCommand</c>,
    /// which supplies a client session handle and a cancellation token when it runs the command.
    /// The <c>*Async</c> methods call the collection directly, without a session.
    /// </remarks>
    /// <typeparam name="TEntity">
    /// Entity type. It must carry a <see cref="BsonCollectionAttribute"/> that names its database and collection.
    /// </typeparam>
    public class Repository<TEntity> where TEntity : class
    {
        protected readonly IDbContext dbContext;

        /// <summary>The collection resolved from <typeparamref name="TEntity"/>'s <see cref="BsonCollectionAttribute"/>.</summary>
        protected IMongoCollection<TEntity> Collection { get; set; }

        /// <summary>Creates the repository and resolves <see cref="Collection"/> from the entity's attribute.</summary>
        /// <param name="_dbContext">Context used to obtain the collection and to queue commands.</param>
        /// <exception cref="InvalidOperationException">
        /// <typeparamref name="TEntity"/> has no <see cref="BsonCollectionAttribute"/>, or the attribute's names are empty.
        /// </exception>
        public Repository(IDbContext _dbContext)
        {
            dbContext = _dbContext;
            Collection = GetCollection();
        }

        /// <summary>
        /// Looks up the <see cref="BsonCollectionAttribute"/> on <typeparamref name="TEntity"/> (inherited
        /// attributes included) and asks <see cref="dbContext"/> for the matching collection.
        /// </summary>
        /// <exception cref="InvalidOperationException">The attribute is missing or has an empty database/collection name.</exception>
        private protected IMongoCollection<TEntity> GetCollection()
        {
            var attribute = typeof(TEntity)
                .GetCustomAttributes(typeof(BsonCollectionAttribute), true)
                .OfType<BsonCollectionAttribute>()
                .FirstOrDefault();

            // Check for a missing attribute explicitly. Otherwise the name checks below would
            // dereference null and surface an uninformative NullReferenceException.
            if (attribute is null
                || string.IsNullOrEmpty(attribute.CollectionName)
                || string.IsNullOrEmpty(attribute.DatabaseName))
            {
                throw new InvalidOperationException("Could not get a valid collection or database name.");
            }

            return dbContext.GetCollection<TEntity>(attribute.DatabaseName, attribute.CollectionName);
        }

        /// <summary>
        /// Builds an equality filter on <c>_id</c>. <see cref="ObjectId.Parse"/> throws if
        /// <paramref name="id"/> is not a valid ObjectId string.
        /// </summary>
        private static FilterDefinition<TEntity> IdFilter(string id)
        {
            return Builders<TEntity>.Filter.Eq("_id", ObjectId.Parse(id));
        }

        /// <summary>Queues an insert of <paramref name="obj"/> as a session-bound command on the context.</summary>
        public virtual void Insert(TEntity obj)
        {
            dbContext.AddCommand((handle, ct) => Collection.InsertOneAsync(handle, obj, null, ct));
        }

        /// <summary>Inserts <paramref name="obj"/> immediately, without a client session.</summary>
        public virtual Task InsertAsync(TEntity obj)
        {
            return Collection.InsertOneAsync(obj);
        }

        /// <summary>
        /// Queues a replace-with-upsert of <paramref name="obj"/> on the first document matching
        /// <paramref name="filterExpression"/>. The document is inserted if nothing matches.
        /// </summary>
        public virtual void Upsert(TEntity obj, Expression<Func<TEntity, bool>> filterExpression)
        {
            dbContext.AddCommand((handle, ct) =>
                Collection.ReplaceOneAsync(handle, filterExpression, obj, new ReplaceOptions { IsUpsert = true }, ct));
        }

        /// <summary>
        /// Immediately replaces the first document matching <paramref name="filterExpression"/> with
        /// <paramref name="obj"/>, inserting it if nothing matches.
        /// </summary>
        public virtual Task UpsertAsync(TEntity obj, Expression<Func<TEntity, bool>> filterExpression)
        {
            return Collection.ReplaceOneAsync(filterExpression, obj, new ReplaceOptions { IsUpsert = true });
        }

        /// <summary>Returns the first entity matching <paramref name="filterExpression"/>, or <c>null</c> if none matches.</summary>
        public virtual async Task<TEntity> FindOneAsync(Expression<Func<TEntity, bool>> filterExpression)
        {
            // FirstOrDefaultAsync on the find fluent applies a limit of 1 and reads the cursor
            // asynchronously, rather than pulling a whole first batch and iterating it synchronously.
            return await Collection.Find(filterExpression).FirstOrDefaultAsync();
        }

        /// <summary>Returns the entity whose <c>_id</c> equals <paramref name="id"/>, or <c>null</c> if none exists.</summary>
        /// <param name="id">String form of an <see cref="ObjectId"/>. <see cref="ObjectId.Parse"/> throws if it is malformed.</param>
        public virtual async Task<TEntity> GetAsync(string id)
        {
            return await Collection.Find(IdFilter(id)).FirstOrDefaultAsync();
        }

        /// <summary>Returns every entity matching <paramref name="filterExpression"/>, fully materialized.</summary>
        public virtual async Task<IEnumerable<TEntity>> FindAsync(Expression<Func<TEntity, bool>> filterExpression)
        {
            // ToListAsync fetches subsequent batches asynchronously instead of blocking on them.
            return await Collection.Find(filterExpression).ToListAsync();
        }

        /// <summary>Returns every entity in the collection, fully materialized.</summary>
        public virtual async Task<IEnumerable<TEntity>> GetAllAsync()
        {
            return await Collection.Find(Builders<TEntity>.Filter.Empty).ToListAsync();
        }

        /// <summary>Queues a replacement of the document whose <c>_id</c> equals <paramref name="id"/> with <paramref name="obj"/>.</summary>
        /// <remarks>
        /// <paramref name="id"/> is parsed here rather than inside the queued command. A malformed id
        /// therefore throws at this call site instead of when the context runs its commands.
        /// </remarks>
        public virtual void Update(string id, TEntity obj)
        {
            var filter = IdFilter(id);
            dbContext.AddCommand((handle, ct) => Collection.ReplaceOneAsync(handle, filter, obj, cancellationToken: ct));
        }

        /// <summary>Immediately replaces the document whose <c>_id</c> equals <paramref name="id"/> with <paramref name="obj"/>.</summary>
        public virtual Task UpdateAsync(string id, TEntity obj)
        {
            return Collection.ReplaceOneAsync(IdFilter(id), obj);
        }

        /// <summary>Queues deletion of the document whose <c>_id</c> equals <paramref name="id"/>.</summary>
        /// <remarks>
        /// <paramref name="id"/> is parsed eagerly, so a malformed id throws at this call site.
        /// </remarks>
        public virtual void DeleteById(string id)
        {
            var filter = IdFilter(id);
            dbContext.AddCommand((handle, ct) => Collection.DeleteOneAsync(handle, filter, cancellationToken: ct));
        }

        /// <summary>Immediately deletes the document whose <c>_id</c> equals <paramref name="id"/>.</summary>
        public virtual Task DeleteByIdAsync(string id)
        {
            return Collection.DeleteOneAsync(IdFilter(id));
        }

        /// <summary>Queues deletion of the first document matching <paramref name="filterExpression"/>.</summary>
        public virtual void Delete(Expression<Func<TEntity, bool>> filterExpression)
        {
            dbContext.AddCommand((handle, ct) => Collection.DeleteOneAsync(handle, filterExpression, cancellationToken: ct));
        }

        /// <summary>Queues deletion of every document matching <paramref name="filterExpression"/>.</summary>
        public virtual void DeleteMany(Expression<Func<TEntity, bool>> filterExpression)
        {
            dbContext.AddCommand((handle, ct) => Collection.DeleteManyAsync(handle, filterExpression, cancellationToken: ct));
        }

        /// <summary>Immediately deletes the first document matching <paramref name="filterExpression"/>.</summary>
        public virtual Task DeleteAsync(Expression<Func<TEntity, bool>> filterExpression)
        {
            return Collection.DeleteOneAsync(filterExpression);
        }

        /// <summary>Immediately deletes every document matching <paramref name="filterExpression"/>.</summary>
        public virtual Task DeleteManyAsync(Expression<Func<TEntity, bool>> filterExpression)
        {
            return Collection.DeleteManyAsync(filterExpression);
        }
    }
}