Skip to content

Commit

Permalink
Correcting the tenant id on all multi-tenancy *just in case*. Closes G…
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremydmiller committed May 21, 2024
1 parent af93884 commit 427dd83
Show file tree
Hide file tree
Showing 13 changed files with 84 additions and 24 deletions.
9 changes: 9 additions & 0 deletions docs/configuration/multitenancy.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ First off, let's try to answer the obvious questions you probably have:
* *Does the `IDocumentStore.Advanced` features work for multiple databases?* - This is a little more complicated, but the answer is still yes. See the very last section on administering databases.
* *Can this strategy use different database schemas in the same database?* - **That's a hard no.** The databases have to be identical in all structures.

## Tenant Id Case Sensitivity

Hey, we've all been there. Our perfectly crafted code fails because of a @#$%#@%ing case sensitivity string comparison.
That's unfortunately happened to Marten users with the `tenantId` values passed into Marten, and it's likely to happen
again. To guard against that, you can force Marten to convert all supplied tenant ids from the outside world to either
upper or lower case to try to stop these kinds of case sensitivity bugs in their tracks like so:

snippet: sample_using_tenant_id_style

## Static Database to Tenant Mapping

::: info
Expand Down
2 changes: 1 addition & 1 deletion src/Marten.PLv8/StoreOptionsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public static async Task TransformAsync(this IDocumentStore store, string tenant
{
var s = store.As<DocumentStore>();

var tenant = await s.Tenancy.GetTenantAsync(tenantId).ConfigureAwait(false);
var tenant = await s.Tenancy.GetTenantAsync(store.Options.MaybeCorrectTenantId(tenantId)).ConfigureAwait(false);

await tenant.Database.EnsureStorageExistsAsync(typeof(TransformSchema), token).ConfigureAwait(false);
using var transforms = new DocumentTransforms(s, tenant);
Expand Down
19 changes: 19 additions & 0 deletions src/Marten.Testing/Examples/MultiTenancy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@ namespace Marten.Testing.Examples;

public class MultiTenancy
{
public static void configuring_tenant_id_rules()
{
#region sample_using_tenant_id_style

var store = DocumentStore.For(opts =>
{
// This is the default
opts.TenantIdStyle = TenantIdStyle.CaseSensitive;

// Or opt into this behavior:
opts.TenantIdStyle = TenantIdStyle.ForceLowerCase;

// Or force all tenant ids to be converted to upper case internally
opts.TenantIdStyle = TenantIdStyle.ForceUpperCase;
});

#endregion
}

[Fact]
public void use_multiple_tenants()
{
Expand Down
7 changes: 4 additions & 3 deletions src/Marten/AdvancedOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ public async Task ResetHiloSequenceFloor<T>(long floor)
/// <param name="floor"></param>
public async Task ResetHiloSequenceFloor<T>(string tenantId, long floor)
{
tenantId = _store.Options.MaybeCorrectTenantId(tenantId);
var tenant = await _store.Tenancy.GetTenantAsync(tenantId).ConfigureAwait(false);
await tenant.Database.ResetHiloSequenceFloor<T>(floor).ConfigureAwait(false);
}
Expand All @@ -115,7 +116,7 @@ public async Task<EventStoreStatistics> FetchEventStoreStatistics(string? tenant
{
var database = tenantId == null
? _store.Tenancy.Default.Database
: (await _store.Tenancy.GetTenantAsync(tenantId).ConfigureAwait(false)).Database;
: (await _store.Tenancy.GetTenantAsync(_store.Options.MaybeCorrectTenantId(tenantId)).ConfigureAwait(false)).Database;

return await database.FetchEventStoreStatistics(token).ConfigureAwait(false);
}
Expand All @@ -134,7 +135,7 @@ public async Task<IReadOnlyList<ShardState>> AllProjectionProgress(string? tenan
{
var database = tenantId == null
? _store.Tenancy.Default.Database
: (await _store.Tenancy.GetTenantAsync(tenantId).ConfigureAwait(false)).Database;
: (await _store.Tenancy.GetTenantAsync(_store.Options.MaybeCorrectTenantId(tenantId)).ConfigureAwait(false)).Database;

return await database.AllProjectionProgress(token).ConfigureAwait(false);
}
Expand All @@ -153,7 +154,7 @@ public async Task<long> ProjectionProgressFor(ShardName name, string? tenantId =
{
var tenant = tenantId == null
? _store.Tenancy.Default
: await _store.Tenancy.GetTenantAsync(tenantId).ConfigureAwait(false);
: await _store.Tenancy.GetTenantAsync(_store.Options.MaybeCorrectTenantId(tenantId)).ConfigureAwait(false);
var database = tenant.Database;

return await database.ProjectionProgressFor(name, token).ConfigureAwait(false);
Expand Down
37 changes: 24 additions & 13 deletions src/Marten/DocumentStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using Marten.Internal.Sessions;
using Marten.Services;
using Marten.Storage;
using Microsoft.CodeAnalysis.Options;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Weasel.Core.Migrations;
Expand Down Expand Up @@ -139,15 +140,15 @@ public void BulkInsert<T>(string tenantId, IReadOnlyCollection<T> documents,
BulkInsertMode mode = BulkInsertMode.InsertsOnly,
int batchSize = 1000)
{
var bulkInsertion = new BulkInsertion(Tenancy.GetTenant(tenantId), Options);
var bulkInsertion = new BulkInsertion(Tenancy.GetTenant(Options.MaybeCorrectTenantId(tenantId)), Options);
bulkInsertion.BulkInsert(documents, mode, batchSize);
}

public void BulkInsertDocuments(string tenantId, IEnumerable<object> documents,
BulkInsertMode mode = BulkInsertMode.InsertsOnly,
int batchSize = 1000)
{
var bulkInsertion = new BulkInsertion(Tenancy.GetTenant(tenantId), Options);
var bulkInsertion = new BulkInsertion(Tenancy.GetTenant(Options.MaybeCorrectTenantId(tenantId)), Options);
bulkInsertion.BulkInsertDocuments(documents, mode, batchSize);
}

Expand All @@ -172,7 +173,7 @@ public async Task BulkInsertAsync<T>(string tenantId, IReadOnlyCollection<T> doc
BulkInsertMode mode = BulkInsertMode.InsertsOnly, int batchSize = 1000,
CancellationToken cancellation = default)
{
var bulkInsertion = new BulkInsertion(await Tenancy.GetTenantAsync(tenantId).ConfigureAwait(false), Options);
var bulkInsertion = new BulkInsertion(await Tenancy.GetTenantAsync(Options.MaybeCorrectTenantId(tenantId)).ConfigureAwait(false), Options);
await bulkInsertion.BulkInsertAsync(documents, mode, batchSize, cancellation).ConfigureAwait(false);
}

Expand All @@ -188,7 +189,7 @@ public async Task BulkInsertDocumentsAsync(string tenantId, IEnumerable<object>
BulkInsertMode mode = BulkInsertMode.InsertsOnly,
int batchSize = 1000, CancellationToken cancellation = default)
{
var bulkInsertion = new BulkInsertion(await Tenancy.GetTenantAsync(tenantId).ConfigureAwait(false), Options);
var bulkInsertion = new BulkInsertion(await Tenancy.GetTenantAsync(Options.MaybeCorrectTenantId(tenantId)).ConfigureAwait(false), Options);
await bulkInsertion.BulkInsertDocumentsAsync(documents, mode, batchSize, cancellation).ConfigureAwait(false);
}

Expand Down Expand Up @@ -231,7 +232,7 @@ public IDocumentSession OpenSession(
DocumentTracking tracking = DocumentTracking.IdentityOnly,
IsolationLevel isolationLevel = IsolationLevel.ReadCommitted
) =>
openSession(new SessionOptions { Tracking = tracking, IsolationLevel = isolationLevel, TenantId = tenantId });
openSession(new SessionOptions { Tracking = tracking, IsolationLevel = isolationLevel, TenantId = Options.MaybeCorrectTenantId(tenantId) });

public IDocumentSession IdentitySession(IsolationLevel isolationLevel = IsolationLevel.ReadCommitted) =>
IdentitySession(new SessionOptions { IsolationLevel = isolationLevel });
Expand All @@ -240,7 +241,7 @@ public IDocumentSession IdentitySession(
string tenantId,
IsolationLevel isolationLevel = IsolationLevel.ReadCommitted
) =>
IdentitySession(new SessionOptions { IsolationLevel = isolationLevel, TenantId = tenantId });
IdentitySession(new SessionOptions { IsolationLevel = isolationLevel, TenantId = Options.MaybeCorrectTenantId(tenantId) });

public IDocumentSession IdentitySession(SessionOptions options)
{
Expand All @@ -261,7 +262,7 @@ public Task<IDocumentSession> IdentitySerializableSessionAsync(
CancellationToken cancellation = default
) =>
IdentitySerializableSessionAsync(
new SessionOptions { IsolationLevel = IsolationLevel.Serializable, TenantId = tenantId },
new SessionOptions { IsolationLevel = IsolationLevel.Serializable, TenantId = Options.MaybeCorrectTenantId(tenantId) },
cancellation
);

Expand All @@ -282,7 +283,7 @@ public IDocumentSession DirtyTrackedSession(
string tenantId,
IsolationLevel isolationLevel = IsolationLevel.ReadCommitted
) =>
DirtyTrackedSession(new SessionOptions { IsolationLevel = isolationLevel, TenantId = tenantId });
DirtyTrackedSession(new SessionOptions { IsolationLevel = isolationLevel, TenantId = Options.MaybeCorrectTenantId(tenantId) });

public IDocumentSession DirtyTrackedSession(SessionOptions options)
{
Expand All @@ -301,7 +302,7 @@ public Task<IDocumentSession> DirtyTrackedSerializableSessionAsync(
CancellationToken cancellation = default
) =>
DirtyTrackedSerializableSessionAsync(
new SessionOptions { IsolationLevel = IsolationLevel.Serializable, TenantId = tenantId }, cancellation);
new SessionOptions { IsolationLevel = IsolationLevel.Serializable, TenantId = Options.MaybeCorrectTenantId(tenantId) }, cancellation);

public Task<IDocumentSession> DirtyTrackedSerializableSessionAsync(
SessionOptions options,
Expand All @@ -320,7 +321,7 @@ public IDocumentSession LightweightSession(
string tenantId,
IsolationLevel isolationLevel = IsolationLevel.ReadCommitted
) =>
LightweightSession(new SessionOptions { IsolationLevel = isolationLevel, TenantId = tenantId });
LightweightSession(new SessionOptions { IsolationLevel = isolationLevel, TenantId = Options.MaybeCorrectTenantId(tenantId) });

public IDocumentSession LightweightSession(SessionOptions options)
{
Expand All @@ -339,7 +340,7 @@ public Task<IDocumentSession> LightweightSerializableSessionAsync(
CancellationToken cancellation = default
) =>
LightweightSerializableSessionAsync(
new SessionOptions { IsolationLevel = IsolationLevel.Serializable, TenantId = tenantId }, cancellation);
new SessionOptions { IsolationLevel = IsolationLevel.Serializable, TenantId = Options.MaybeCorrectTenantId(tenantId) }, cancellation);

public Task<IDocumentSession> LightweightSerializableSessionAsync(
SessionOptions options,
Expand All @@ -362,7 +363,7 @@ public IQuerySession QuerySession() =>
QuerySession(Marten.Storage.Tenancy.DefaultTenantId);

public IQuerySession QuerySession(string tenantId) =>
QuerySession(new SessionOptions { TenantId = tenantId });
QuerySession(new SessionOptions { TenantId = Options.MaybeCorrectTenantId(tenantId) });

public async Task<IQuerySession> QuerySerializableSessionAsync(
SessionOptions options,
Expand All @@ -383,13 +384,18 @@ public Task<IQuerySession> QuerySerializableSessionAsync(
CancellationToken cancellation = default
) =>
QuerySerializableSessionAsync(
new SessionOptions { TenantId = tenantId, IsolationLevel = IsolationLevel.Serializable }, cancellation);
new SessionOptions { TenantId = Options.MaybeCorrectTenantId(tenantId), IsolationLevel = IsolationLevel.Serializable }, cancellation);

public IProjectionDaemon BuildProjectionDaemon(
string? tenantIdOrDatabaseIdentifier = null,
ILogger? logger = null
)
{
if (tenantIdOrDatabaseIdentifier.IsNotEmpty())
{
tenantIdOrDatabaseIdentifier = Options.MaybeCorrectTenantId(tenantIdOrDatabaseIdentifier);
}

AssertTenantOrDatabaseIdentifierIsValid(tenantIdOrDatabaseIdentifier);

logger ??= new NulloLogger();
Expand All @@ -408,6 +414,11 @@ public async ValueTask<IProjectionDaemon> BuildProjectionDaemonAsync(
ILogger? logger = null
)
{
if (tenantIdOrDatabaseIdentifier.IsNotEmpty())
{
tenantIdOrDatabaseIdentifier = Options.MaybeCorrectTenantId(tenantIdOrDatabaseIdentifier);
}

AssertTenantOrDatabaseIdentifierIsValid(tenantIdOrDatabaseIdentifier);

logger ??= Options.LogFactory?.CreateLogger<ProjectionDaemon>() ?? Options.DotNetLogger ?? NullLogger.Instance;
Expand Down
2 changes: 2 additions & 0 deletions src/Marten/IReadOnlyStoreOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,6 @@ public interface IReadOnlyStoreOptions
/// Get database schema names for configured tables
/// </summary>
IDocumentSchemaResolver Schema { get; }

string MaybeCorrectTenantId(string tenantId);
}
4 changes: 2 additions & 2 deletions src/Marten/Services/SessionOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ internal IConnectionLifetime Initialize(DocumentStore store, CommandRunnerMode m
OpenTelemetryOptions telemetryOptions)
{
Mode = mode;
Tenant ??= TenantId != Tenancy.DefaultTenantId ? store.Tenancy.GetTenant(TenantId) : store.Tenancy.Default;
Tenant ??= TenantId != Tenancy.DefaultTenantId ? store.Tenancy.GetTenant(store.Options.MaybeCorrectTenantId(TenantId)) : store.Tenancy.Default;

if (!AllowAnyTenant && !store.Options.Advanced.DefaultTenantUsageEnabled &&
Tenant.TenantId == Tenancy.DefaultTenantId)
Expand Down Expand Up @@ -156,7 +156,7 @@ internal async Task<IConnectionLifetime> InitializeAsync(DocumentStore store, Co
{
Mode = mode;
Tenant ??= TenantId != Tenancy.DefaultTenantId
? await store.Tenancy.GetTenantAsync(TenantId).ConfigureAwait(false)
? await store.Tenancy.GetTenantAsync(store.Options.MaybeCorrectTenantId(TenantId)).ConfigureAwait(false)
: store.Tenancy.Default;

if (!AllowAnyTenant && !store.Options.Advanced.DefaultTenantUsageEnabled &&
Expand Down
8 changes: 7 additions & 1 deletion src/Marten/Storage/CompositeDocumentCleaner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ namespace Marten.Storage;
public class CompositeDocumentCleaner: IDocumentCleaner
{
private readonly ITenancy _tenancy;
private readonly StoreOptions _options;

public CompositeDocumentCleaner(ITenancy tenancy)
public CompositeDocumentCleaner(ITenancy tenancy, StoreOptions options)
{
_tenancy = tenancy;
_options = options;
}


Expand Down Expand Up @@ -90,6 +92,8 @@ public async Task DeleteSingleEventStreamAsync(Guid streamId, string? tenantId =
await applyToAll(d => d.DeleteSingleEventStreamAsync(streamId, tenantId, ct)).ConfigureAwait(false);
}

tenantId = _options.MaybeCorrectTenantId(tenantId);

var tenant = await _tenancy.GetTenantAsync(tenantId).ConfigureAwait(false);
await tenant.Database.DeleteSingleEventStreamAsync(streamId, tenantId, ct).ConfigureAwait(false);
}
Expand All @@ -107,6 +111,8 @@ public async Task DeleteSingleEventStreamAsync(string streamId, string? tenantId
await applyToAll(d => d.DeleteSingleEventStreamAsync(streamId, tenantId, ct)).ConfigureAwait(false);
}

tenantId = _options.MaybeCorrectTenantId(tenantId);

var tenant = await _tenancy.GetTenantAsync(tenantId).ConfigureAwait(false);
await tenant.Database.DeleteSingleEventStreamAsync(streamId, tenantId, ct).ConfigureAwait(false);
}
Expand Down
10 changes: 9 additions & 1 deletion src/Marten/Storage/MasterTableTenancy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public MasterTableTenancy(StoreOptions options, MasterTableTenancyOptions tenanc
}

_schemaName = tenancyOptions.SchemaName;
Cleaner = new CompositeDocumentCleaner(this);
Cleaner = new CompositeDocumentCleaner(this, _options);

_tenantDatabase = new Lazy<TenantLookupDatabase>(() =>
new TenantLookupDatabase(_options, _dataSource.Value, tenancyOptions.SchemaName));
Expand Down Expand Up @@ -158,6 +158,7 @@ public async ValueTask<IReadOnlyList<IDatabase>> BuildDatabases()
while (await reader.ReadAsync().ConfigureAwait(false))
{
var tenantId = await reader.GetFieldValueAsync<string>(0).ConfigureAwait(false);
tenantId = _options.MaybeCorrectTenantId(tenantId);

// Be idempotent, don't duplicate
if (_databases.Contains(tenantId))
Expand Down Expand Up @@ -191,6 +192,7 @@ public async ValueTask<IReadOnlyList<IDatabase>> BuildDatabases()

public Tenant GetTenant(string tenantId)
{
tenantId = _options.MaybeCorrectTenantId(tenantId);
if (_databases.TryFind(tenantId, out var database))
{
return new Tenant(tenantId, database);
Expand All @@ -211,6 +213,7 @@ public Tenant GetTenant(string tenantId)

public async ValueTask<Tenant> GetTenantAsync(string tenantId)
{
tenantId = _options.MaybeCorrectTenantId(tenantId);
if (_databases.TryFind(tenantId, out var database))
{
return new Tenant(tenantId, database);
Expand All @@ -229,6 +232,7 @@ public async ValueTask<Tenant> GetTenantAsync(string tenantId)

public async ValueTask<IMartenDatabase> FindOrCreateDatabase(string tenantIdOrDatabaseIdentifier)
{
tenantIdOrDatabaseIdentifier = _options.MaybeCorrectTenantId(tenantIdOrDatabaseIdentifier);
if (_databases.TryFind(tenantIdOrDatabaseIdentifier, out var database))
{
return database;
Expand All @@ -245,13 +249,15 @@ public async ValueTask<IMartenDatabase> FindOrCreateDatabase(string tenantIdOrDa

public bool IsTenantStoredInCurrentDatabase(IMartenDatabase database, string tenantId)
{
tenantId = _options.MaybeCorrectTenantId(tenantId);
return database.Identifier == tenantId;
}

public PostgresqlDatabase TenantDatabase => _tenantDatabase.Value;

public async Task DeleteDatabaseRecordAsync(string tenantId)
{
tenantId = _options.MaybeCorrectTenantId(tenantId);
await maybeApplyChanges(_tenantDatabase.Value).ConfigureAwait(false);

await _dataSource.Value
Expand All @@ -270,6 +276,7 @@ await _dataSource.Value.CreateCommand($"delete from {_schemaName}.{TenantTable.T

public async Task AddDatabaseRecordAsync(string tenantId, string connectionString)
{
tenantId = _options.MaybeCorrectTenantId(tenantId);
await _dataSource.Value
.CreateCommand(
$"insert into {_schemaName}.{TenantTable.TableName} (tenant_id, connection_string) values (:id, :connection) on conflict (tenant_id) do update set connection_string = :connection")
Expand Down Expand Up @@ -318,6 +325,7 @@ private async Task seedDatabasesAsync(NpgsqlConnection conn)

private async Task<MartenDatabase?> tryFindTenantDatabase(string tenantId)
{
tenantId = _options.MaybeCorrectTenantId(tenantId);
var connectionString = (string)await _dataSource.Value
.CreateCommand($"select connection_string from {_schemaName}.{TenantTable.TableName} where tenant_id = :id")
.With("id", tenantId)
Expand Down
4 changes: 2 additions & 2 deletions src/Marten/Storage/SingleServerMultiTenancy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ StoreOptions options
): base(dataSourceFactory, masterConnectionString)
{
_options = options;
Cleaner = new CompositeDocumentCleaner(this);
Cleaner = new CompositeDocumentCleaner(this, options);

_masterDataSource =
new Lazy<NpgsqlDataSource>(() => options.NpgsqlDataSourceFactory.Create(masterConnectionString));
Expand Down Expand Up @@ -152,7 +152,7 @@ public async ValueTask<Tenant> GetTenantAsync(string tenantId)

public new async ValueTask<IMartenDatabase> FindOrCreateDatabase(string tenantIdOrDatabaseIdentifier)
{
var tenant = await GetTenantAsync(tenantIdOrDatabaseIdentifier).ConfigureAwait(false);
var tenant = await GetTenantAsync(_options.MaybeCorrectTenantId(tenantIdOrDatabaseIdentifier)).ConfigureAwait(false);
return tenant.Database;
}

Expand Down
2 changes: 1 addition & 1 deletion src/Marten/Storage/StaticMultiTenancy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class StaticMultiTenancy: Tenancy, ITenancy, IStaticMultiTenancy
public StaticMultiTenancy(INpgsqlDataSourceFactory dataSourceFactory, StoreOptions options): base(options)
{
_dataSourceFactory = dataSourceFactory;
Cleaner = new CompositeDocumentCleaner(this);
Cleaner = new CompositeDocumentCleaner(this, options);
}

public void Dispose()
Expand Down
2 changes: 2 additions & 0 deletions src/Marten/StoreOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ public StoreOptions()

public string MaybeCorrectTenantId(string tenantId)
{
if (tenantId == Marten.Storage.Tenancy.DefaultTenantId) return tenantId;

switch (TenantIdStyle)
{
case TenantIdStyle.CaseSensitive:
Expand Down
Loading

0 comments on commit 427dd83

Please sign in to comment.