From 8b8fe3b4bfe3cde73c4088387c42c5e83c504cee Mon Sep 17 00:00:00 2001 From: Rain Valentine Date: Thu, 2 Jan 2025 06:10:47 +0000 Subject: [PATCH 1/3] Replace dict with new hashtable: hash datatype Signed-off-by: Rain Valentine --- src/aof.c | 2 +- src/db.c | 69 +++++-------------- src/debug.c | 10 +-- src/defrag.c | 67 +++++++++--------- src/lazyfree.c | 6 +- src/module.c | 36 ++-------- src/object.c | 56 ++++++++------- src/rdb.c | 62 ++++++++--------- src/server.c | 33 +++++++++ src/server.h | 18 +++-- src/t_hash.c | 183 ++++++++++++++++++++++++------------------------- 11 files changed, 259 insertions(+), 283 deletions(-) diff --git a/src/aof.c b/src/aof.c index 5dc12db61e..b02661c5a3 100644 --- a/src/aof.c +++ b/src/aof.c @@ -1936,7 +1936,7 @@ static int rioWriteHashIteratorCursor(rio *r, hashTypeIterator *hi, int what) { return rioWriteBulkString(r, (char *)vstr, vlen); else return rioWriteBulkLongLong(r, vll); - } else if (hi->encoding == OBJ_ENCODING_HT) { + } else if (hi->encoding == OBJ_ENCODING_HASHTABLE) { sds value = hashTypeCurrentFromHashTable(hi, what); return rioWriteBulkString(r, value, sdslen(value)); } diff --git a/src/db.c b/src/db.c index 55ffe5da5a..4f78f3a157 100644 --- a/src/db.c +++ b/src/db.c @@ -979,39 +979,6 @@ void keysScanCallback(void *privdata, void *entry) { /* This callback is used by scanGenericCommand in order to collect elements * returned by the dictionary iterator into a list. */ -void dictScanCallback(void *privdata, const dictEntry *de) { - scanData *data = (scanData *)privdata; - list *keys = data->keys; - robj *o = data->o; - sds val = NULL; - sds key = NULL; - data->sampled++; - - /* This callback is only used for scanning elements within a key (hash - * fields, set elements, etc.) so o must be set here. */ - serverAssert(o != NULL); - - /* Filter element if it does not match the pattern. 
*/ - sds keysds = dictGetKey(de); - if (data->pattern) { - if (!stringmatchlen(data->pattern, sdslen(data->pattern), keysds, sdslen(keysds), 0)) { - return; - } - } - - if (o->type == OBJ_HASH) { - key = keysds; - if (!data->only_keys) { - val = dictGetVal(de); - } - } else { - serverPanic("Type not handled in dict SCAN callback."); - } - - listAddNodeTail(keys, key); - if (val) listAddNodeTail(keys, val); -} - void hashtableScanCallback(void *privdata, void *entry) { scanData *data = (scanData *)privdata; sds val = NULL; @@ -1025,14 +992,21 @@ void hashtableScanCallback(void *privdata, void *entry) { * fields, set elements, etc.) so o must be set here. */ serverAssert(o != NULL); - /* get key */ + /* get key, value */ if (o->type == OBJ_SET) { key = (sds)entry; } else if (o->type == OBJ_ZSET) { zskiplistNode *node = (zskiplistNode *)entry; key = node->ele; + /* zset data is copied after filtering by key */ + } else if (o->type == OBJ_HASH) { + key = (sds)hashTypeEntryGetKey(entry); + if (!data->only_keys) { + hashTypeEntry *hash_entry = entry; + val = hash_entry->value; + } } else { - serverPanic("Type not handled in hashset SCAN callback."); + serverPanic("Type not handled in hashtable SCAN callback."); } /* Filter element if it does not match the pattern. */ @@ -1042,9 +1016,9 @@ void hashtableScanCallback(void *privdata, void *entry) { } } - if (o->type == OBJ_SET) { - /* no value, key used by reference */ - } else if (o->type == OBJ_ZSET) { + /* zset data must be copied. Do this after filtering to avoid unneeded + * allocations. 
*/ + if (o->type == OBJ_ZSET) { /* zset data is copied */ zskiplistNode *node = (zskiplistNode *)entry; key = sdsdup(node->ele); @@ -1053,8 +1027,6 @@ void hashtableScanCallback(void *privdata, void *entry) { int len = ld2string(buf, sizeof(buf), node->score, LD_STR_AUTO); val = sdsnewlen(buf, len); } - } else { - serverPanic("Type not handled in hashset SCAN callback."); } listAddNodeTail(keys, key); @@ -1193,20 +1165,19 @@ void scanGenericCommand(client *c, robj *o, unsigned long long cursor) { * cursor to zero to signal the end of the iteration. */ /* Handle the case of kvstore, dict or hashtable. */ - dict *dict_table = NULL; - hashtable *hashtable_table = NULL; + hashtable *ht = NULL; int shallow_copied_list_items = 0; if (o == NULL) { shallow_copied_list_items = 1; } else if (o->type == OBJ_SET && o->encoding == OBJ_ENCODING_HASHTABLE) { - hashtable_table = o->ptr; + ht = o->ptr; shallow_copied_list_items = 1; - } else if (o->type == OBJ_HASH && o->encoding == OBJ_ENCODING_HT) { - dict_table = o->ptr; + } else if (o->type == OBJ_HASH && o->encoding == OBJ_ENCODING_HASHTABLE) { + ht = o->ptr; shallow_copied_list_items = 1; } else if (o->type == OBJ_ZSET && o->encoding == OBJ_ENCODING_SKIPLIST) { zset *zs = o->ptr; - hashtable_table = zs->ht; + ht = zs->ht; /* scanning ZSET allocates temporary strings even though it's a dict */ shallow_copied_list_items = 0; } @@ -1220,7 +1191,7 @@ void scanGenericCommand(client *c, robj *o, unsigned long long cursor) { } /* For main hash table scan or scannable data structure. */ - if (!o || dict_table || hashtable_table) { + if (!o || ht) { /* We set the max number of iterations to ten times the specified * COUNT, so if the hash table is in a pathological state (very * sparsely populated) we avoid to block too much time at the cost @@ -1260,10 +1231,8 @@ void scanGenericCommand(client *c, robj *o, unsigned long long cursor) { * If cursor is empty, we should try exploring next non-empty slot. 
*/ if (o == NULL) { cursor = kvstoreScan(c->db->keys, cursor, onlydidx, keysScanCallback, NULL, &data); - } else if (dict_table) { - cursor = dictScan(dict_table, cursor, dictScanCallback, &data); } else { - cursor = hashtableScan(hashtable_table, cursor, hashtableScanCallback, &data); + cursor = hashtableScan(ht, cursor, hashtableScanCallback, &data); } } while (cursor && maxiterations-- && data.sampled < count); } else if (o->type == OBJ_SET) { diff --git a/src/debug.c b/src/debug.c index c80ff5af39..ebaee54f92 100644 --- a/src/debug.c +++ b/src/debug.c @@ -923,23 +923,17 @@ void debugCommand(client *c) { robj *o = objectCommandLookupOrReply(c, c->argv[2], shared.nokeyerr); if (o == NULL) return; - /* Get the dict reference from the object, if possible. */ - dict *d = NULL; + /* Get the hashtable reference from the object, if possible. */ hashtable *ht = NULL; switch (o->encoding) { case OBJ_ENCODING_SKIPLIST: { zset *zs = o->ptr; ht = zs->ht; } break; - case OBJ_ENCODING_HT: d = o->ptr; break; case OBJ_ENCODING_HASHTABLE: ht = o->ptr; break; } - if (d != NULL) { - char buf[4096]; - dictGetStats(buf, sizeof(buf), d, full); - addReplyVerbatim(c, buf, strlen(buf), "txt"); - } else if (ht != NULL) { + if (ht != NULL) { char buf[4096]; hashtableGetStats(buf, sizeof(buf), ht, full); addReplyVerbatim(c, buf, strlen(buf), "txt"); diff --git a/src/defrag.c b/src/defrag.c index 103730ee14..bf9632264b 100644 --- a/src/defrag.c +++ b/src/defrag.c @@ -373,13 +373,6 @@ void activeDefragSdsHashtableCallback(void *privdata, void *entry_ref) { if (new_sds != NULL) *sds_ref = new_sds; } -void activeDefragSdsHashtable(hashtable *ht) { - unsigned long cursor = 0; - do { - cursor = hashtableScanDefrag(ht, cursor, activeDefragSdsHashtableCallback, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); - } while (cursor != 0); -} - /* Defrag a list of ptr, sds or robj string values */ static void activeDefragQuickListNode(quicklist *ql, quicklistNode **node_ref) { quicklistNode *newnode, 
*node = *node_ref; @@ -481,26 +474,29 @@ static void scanHashtableCallbackCountScanned(void *privdata, void *elemref) { server.stat_active_defrag_scanned++; } -/* Used as dict scan callback when all the work is done in the dictDefragFunctions. */ -static void scanCallbackCountScanned(void *privdata, const dictEntry *de) { - UNUSED(privdata); - UNUSED(de); - server.stat_active_defrag_scanned++; -} - static void scanLaterSet(robj *ob, unsigned long *cursor) { if (ob->type != OBJ_SET || ob->encoding != OBJ_ENCODING_HASHTABLE) return; hashtable *ht = ob->ptr; *cursor = hashtableScanDefrag(ht, *cursor, activeDefragSdsHashtableCallback, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); } +static void activeDefragHashField(void *privdata, void *element_ref) { + UNUSED(privdata); + hashTypeEntry **entry_ref = (hashTypeEntry **)element_ref; + + /* defragment field */ + hashTypeEntry *new_entry = activeDefragAlloc(*entry_ref); + if (new_entry) *entry_ref = new_entry; + + /* defragment value */ + sds new_value = activeDefragSds((*entry_ref)->value); + if (new_value) (*entry_ref)->value = new_value; +} + static void scanLaterHash(robj *ob, unsigned long *cursor) { - if (ob->type != OBJ_HASH || ob->encoding != OBJ_ENCODING_HT) return; - dict *d = ob->ptr; - dictDefragFunctions defragfns = {.defragAlloc = activeDefragAlloc, - .defragKey = (dictDefragAllocFunction *)activeDefragSds, - .defragVal = (dictDefragAllocFunction *)activeDefragSds}; - *cursor = dictScanDefrag(d, *cursor, scanCallbackCountScanned, &defragfns, NULL); + if (ob->type != OBJ_HASH || ob->encoding != OBJ_ENCODING_HASHTABLE) return; + hashtable *ht = ob->ptr; + *cursor = hashtableScanDefrag(ht, *cursor, activeDefragHashField, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); } static void defragQuicklist(robj *ob) { @@ -538,15 +534,19 @@ static void defragZsetSkiplist(robj *ob) { } static void defragHash(robj *ob) { - dict *d, *newd; - serverAssert(ob->type == OBJ_HASH && ob->encoding == OBJ_ENCODING_HT); - d 
= ob->ptr; - if (dictSize(d) > server.active_defrag_max_scan_fields) + serverAssert(ob->type == OBJ_HASH && ob->encoding == OBJ_ENCODING_HASHTABLE); + hashtable *ht = ob->ptr; + if (hashtableSize(ht) > server.active_defrag_max_scan_fields) { defragLater(ob); - else - activeDefragSdsDict(d, DEFRAG_SDS_DICT_VAL_IS_SDS); - /* defrag the dict struct and tables */ - if ((newd = dictDefragTables(ob->ptr))) ob->ptr = newd; + } else { + unsigned long cursor = 0; + do { + cursor = hashtableScanDefrag(ht, cursor, activeDefragHashField, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); + } while (cursor != 0); + } + /* defrag the hashtable struct and tables */ + hashtable *new_hashtable = hashtableDefragTables(ht, activeDefragAlloc); + if (new_hashtable) ob->ptr = new_hashtable; } static void defragSet(robj *ob) { @@ -555,11 +555,14 @@ static void defragSet(robj *ob) { if (hashtableSize(ht) > server.active_defrag_max_scan_fields) { defragLater(ob); } else { - activeDefragSdsHashtable(ht); + unsigned long cursor = 0; + do { + cursor = hashtableScanDefrag(ht, cursor, activeDefragSdsHashtableCallback, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); + } while (cursor != 0); } /* defrag the hashtable struct and tables */ - hashtable *newHashtable = hashtableDefragTables(ht, activeDefragAlloc); - if (newHashtable) ob->ptr = newHashtable; + hashtable *new_hashtable = hashtableDefragTables(ht, activeDefragAlloc); + if (new_hashtable) ob->ptr = new_hashtable; } /* Defrag callback for radix tree iterator, called for each node, @@ -776,7 +779,7 @@ static void defragKey(defragKeysCtx *ctx, robj **elemref) { } else if (ob->type == OBJ_HASH) { if (ob->encoding == OBJ_ENCODING_LISTPACK) { if ((newzl = activeDefragAlloc(ob->ptr))) ob->ptr = newzl; - } else if (ob->encoding == OBJ_ENCODING_HT) { + } else if (ob->encoding == OBJ_ENCODING_HASHTABLE) { defragHash(ob); } else { serverPanic("Unknown hash encoding"); diff --git a/src/lazyfree.c b/src/lazyfree.c index c22d3da964..3b061ccd84 
100644 --- a/src/lazyfree.c +++ b/src/lazyfree.c @@ -123,9 +123,9 @@ size_t lazyfreeGetFreeEffort(robj *key, robj *obj, int dbid) { } else if (obj->type == OBJ_ZSET && obj->encoding == OBJ_ENCODING_SKIPLIST) { zset *zs = obj->ptr; return zs->zsl->length; - } else if (obj->type == OBJ_HASH && obj->encoding == OBJ_ENCODING_HT) { - dict *ht = obj->ptr; - return dictSize(ht); + } else if (obj->type == OBJ_HASH && obj->encoding == OBJ_ENCODING_HASHTABLE) { + hashtable *ht = obj->ptr; + return hashtableSize(ht); } else if (obj->type == OBJ_STREAM) { size_t effort = 0; stream *s = obj->ptr; diff --git a/src/module.c b/src/module.c index 58555839f2..809fcfb8e2 100644 --- a/src/module.c +++ b/src/module.c @@ -11090,25 +11090,6 @@ typedef struct { ValkeyModuleScanKeyCB fn; } ScanKeyCBData; -static void moduleScanKeyDictCallback(void *privdata, const dictEntry *de) { - ScanKeyCBData *data = privdata; - sds key = dictGetKey(de); - robj *o = data->key->value; - robj *field = createStringObject(key, sdslen(key)); - robj *value = NULL; - - if (o->type == OBJ_HASH) { - sds val = dictGetVal(de); - value = createStringObject(val, sdslen(val)); - } else { - serverPanic("unexpected object type"); - } - - data->fn(data->key, field, value, data->user_data); - decrRefCount(field); - if (value) decrRefCount(value); -} - static void moduleScanKeyHashtableCallback(void *privdata, void *entry) { ScanKeyCBData *data = privdata; robj *o = data->key->value; @@ -11122,6 +11103,11 @@ static void moduleScanKeyHashtableCallback(void *privdata, void *entry) { zskiplistNode *node = (zskiplistNode *)entry; key = node->ele; value = createStringObjectFromLongDouble(node->score, 0); + } else if (o->type == OBJ_HASH) { + key = (sds)hashTypeEntryGetKey(entry); + hashTypeEntry *hash_entry = entry; + sds val = hash_entry->value; + value = createStringObject(val, sdslen(val)); } else { serverPanic("unexpected object type"); } @@ -11185,13 +11171,12 @@ int VM_ScanKey(ValkeyModuleKey *key, 
ValkeyModuleScanCursor *cursor, ValkeyModul errno = EINVAL; return 0; } - dict *d = NULL; hashtable *ht = NULL; robj *o = key->value; if (o->type == OBJ_SET) { if (o->encoding == OBJ_ENCODING_HASHTABLE) ht = o->ptr; } else if (o->type == OBJ_HASH) { - if (o->encoding == OBJ_ENCODING_HT) d = o->ptr; + if (o->encoding == OBJ_ENCODING_HASHTABLE) ht = o->ptr; } else if (o->type == OBJ_ZSET) { if (o->encoding == OBJ_ENCODING_SKIPLIST) ht = ((zset *)o->ptr)->ht; } else { @@ -11203,14 +11188,7 @@ int VM_ScanKey(ValkeyModuleKey *key, ValkeyModuleScanCursor *cursor, ValkeyModul return 0; } int ret = 1; - if (d) { - ScanKeyCBData data = {key, privdata, fn}; - cursor->cursor = dictScan(d, cursor->cursor, moduleScanKeyDictCallback, &data); - if (cursor->cursor == 0) { - cursor->done = 1; - ret = 0; - } - } else if (ht) { + if (ht) { ScanKeyCBData data = {key, privdata, fn}; cursor->cursor = hashtableScan(ht, cursor->cursor, moduleScanKeyHashtableCallback, &data); if (cursor->cursor == 0) { diff --git a/src/object.c b/src/object.c index 86eefe43a3..8c6bfa6cd8 100644 --- a/src/object.c +++ b/src/object.c @@ -530,7 +530,7 @@ void freeZsetObject(robj *o) { void freeHashObject(robj *o) { switch (o->encoding) { - case OBJ_ENCODING_HT: dictRelease((dict *)o->ptr); break; + case OBJ_ENCODING_HASHTABLE: hashtableRelease((hashtable *)o->ptr); break; case OBJ_ENCODING_LISTPACK: lpFree(o->ptr); break; default: serverPanic("Unknown hash encoding type"); break; } @@ -675,25 +675,26 @@ void dismissZsetObject(robj *o, size_t size_hint) { /* See dismissObject() */ void dismissHashObject(robj *o, size_t size_hint) { - if (o->encoding == OBJ_ENCODING_HT) { - dict *d = o->ptr; - serverAssert(dictSize(d) != 0); + if (o->encoding == OBJ_ENCODING_HASHTABLE) { + hashtable *ht = o->ptr; + serverAssert(hashtableSize(ht) != 0); /* We iterate all fields only when average field/value size is bigger than * a page size, and there's a high chance we'll actually dismiss something. 
*/ - if (size_hint / dictSize(d) >= server.page_size) { - dictEntry *de; - dictIterator *di = dictGetIterator(d); - while ((de = dictNext(di)) != NULL) { + if (size_hint / hashtableSize(ht) >= server.page_size) { + hashtableIterator iter; + hashtableInitIterator(&iter, ht); + void *next; + while (hashtableNext(&iter, &next)) { /* Only dismiss values memory since the field size * usually is small. */ - dismissSds(dictGetVal(de)); + hashTypeEntry *entry = next; + UNUSED(entry); + dismissSds(entry->value); } - dictReleaseIterator(di); + hashtableResetIterator(&iter); } - /* Dismiss hash table memory. */ - dismissMemory(d->ht_table[0], DICTHT_SIZE(d->ht_size_exp[0]) * sizeof(dictEntry *)); - dismissMemory(d->ht_table[1], DICTHT_SIZE(d->ht_size_exp[1]) * sizeof(dictEntry *)); + dismissHashtable(ht); } else if (o->encoding == OBJ_ENCODING_LISTPACK) { dismissMemory(o->ptr, lpBytes((unsigned char *)o->ptr)); } else { @@ -1106,7 +1107,6 @@ char *strEncoding(int encoding) { switch (encoding) { case OBJ_ENCODING_RAW: return "raw"; case OBJ_ENCODING_INT: return "int"; - case OBJ_ENCODING_HT: return "hashtable"; case OBJ_ENCODING_HASHTABLE: return "hashtable"; case OBJ_ENCODING_QUICKLIST: return "quicklist"; case OBJ_ENCODING_LISTPACK: return "listpack"; @@ -1127,10 +1127,6 @@ char *strEncoding(int encoding) { * are checked and averaged to estimate the total size. */ #define OBJ_COMPUTE_SIZE_DEF_SAMPLES 5 /* Default sample size. 
*/ size_t objectComputeSize(robj *key, robj *o, size_t sample_size, int dbid) { - sds ele, ele2; - dict *d; - dictIterator *di; - struct dictEntry *de; size_t asize = 0, elesize = 0, samples = 0; if (o->type == OBJ_STRING) { @@ -1202,19 +1198,21 @@ size_t objectComputeSize(robj *key, robj *o, size_t sample_size, int dbid) { } else if (o->type == OBJ_HASH) { if (o->encoding == OBJ_ENCODING_LISTPACK) { asize = sizeof(*o) + zmalloc_size(o->ptr); - } else if (o->encoding == OBJ_ENCODING_HT) { - d = o->ptr; - di = dictGetIterator(d); - asize = sizeof(*o) + sizeof(dict) + (sizeof(struct dictEntry *) * dictBuckets(d)); - while ((de = dictNext(di)) != NULL && samples < sample_size) { - ele = dictGetKey(de); - ele2 = dictGetVal(de); - elesize += sdsAllocSize(ele) + sdsAllocSize(ele2); - elesize += dictEntryMemUsage(de); + } else if (o->encoding == OBJ_ENCODING_HASHTABLE) { + hashtable *ht = o->ptr; + hashtableIterator iter; + hashtableInitIterator(&iter, ht); + void *next; + + asize = sizeof(*o) + hashtableMemUsage(ht); + while (hashtableNext(&iter, &next) && samples < sample_size) { + elesize += zmalloc_usable_size(next); + hashTypeEntry *entry = next; + elesize += sdsAllocSize(entry->value); samples++; } - dictReleaseIterator(di); - if (samples) asize += (double)elesize / samples * dictSize(d); + hashtableResetIterator(&iter); + if (samples) asize += (double)elesize / samples * hashtableSize(ht); } else { serverPanic("Unknown hash encoding"); } diff --git a/src/rdb.c b/src/rdb.c index 6a2ec78d71..59c97732e2 100644 --- a/src/rdb.c +++ b/src/rdb.c @@ -710,7 +710,7 @@ int rdbSaveObjectType(rio *rdb, robj *o) { case OBJ_HASH: if (o->encoding == OBJ_ENCODING_LISTPACK) return rdbSaveType(rdb, RDB_TYPE_HASH_LISTPACK); - else if (o->encoding == OBJ_ENCODING_HT) + else if (o->encoding == OBJ_ENCODING_HASHTABLE) return rdbSaveType(rdb, RDB_TYPE_HASH); else serverPanic("Unknown hash encoding"); @@ -950,32 +950,33 @@ ssize_t rdbSaveObject(rio *rdb, robj *o, robj *key, int dbid) { if 
((n = rdbSaveRawString(rdb, o->ptr, l)) == -1) return -1; nwritten += n; - } else if (o->encoding == OBJ_ENCODING_HT) { - dictIterator *di = dictGetIterator(o->ptr); - dictEntry *de; + } else if (o->encoding == OBJ_ENCODING_HASHTABLE) { + hashtable *ht = o->ptr; - if ((n = rdbSaveLen(rdb, dictSize((dict *)o->ptr))) == -1) { - dictReleaseIterator(di); + if ((n = rdbSaveLen(rdb, hashtableSize(ht))) == -1) { return -1; } nwritten += n; - while ((de = dictNext(di)) != NULL) { - sds field = dictGetKey(de); - sds value = dictGetVal(de); + hashtableIterator iter; + hashtableInitIterator(&iter, ht); + void *next; + while (hashtableNext(&iter, &next)) { + sds field = (sds)hashTypeEntryGetKey(next); + sds value = ((hashTypeEntry *)next)->value; if ((n = rdbSaveRawString(rdb, (unsigned char *)field, sdslen(field))) == -1) { - dictReleaseIterator(di); + hashtableResetIterator(&iter); return -1; } nwritten += n; if ((n = rdbSaveRawString(rdb, (unsigned char *)value, sdslen(value))) == -1) { - dictReleaseIterator(di); + hashtableResetIterator(&iter); return -1; } nwritten += n; } - dictReleaseIterator(di); + hashtableResetIterator(&iter); } else { serverPanic("Unknown hash encoding"); } @@ -2063,7 +2064,6 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { } } else if (rdbtype == RDB_TYPE_HASH) { uint64_t len; - int ret; sds field, value; dict *dupSearchDict = NULL; @@ -2075,10 +2075,10 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { /* Too many entries? Use a hash table right from the start. */ if (len > server.hash_max_listpack_entries) - hashTypeConvert(o, OBJ_ENCODING_HT); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); else if (deep_integrity_validation) { /* In this mode, we need to guarantee that the server won't crash - * later when the ziplist is converted to a dict. + * later when the ziplist is converted to a hashtable. * Create a set (dict with no values) to for a dup search. 
* We can dismiss it as soon as we convert the ziplist to a hash. */ dupSearchDict = dictCreate(&hashDictType); @@ -2117,13 +2117,13 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { /* Convert to hash table if size threshold is exceeded */ if (sdslen(field) > server.hash_max_listpack_value || sdslen(value) > server.hash_max_listpack_value || !lpSafeToAdd(o->ptr, sdslen(field) + sdslen(value))) { - hashTypeConvert(o, OBJ_ENCODING_HT); - ret = dictAdd((dict *)o->ptr, field, value); - if (ret == DICT_ERR) { + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); + hashTypeEntry *entry = hashTypeCreateEntry(field, value); + sdsfree(field); + if (!hashtableAdd((hashtable *)o->ptr, entry)) { rdbReportCorruptRDB("Duplicate hash fields detected"); if (dupSearchDict) dictRelease(dupSearchDict); - sdsfree(value); - sdsfree(field); + freeHashTypeEntry(entry); decrRefCount(o); return NULL; } @@ -2145,16 +2145,16 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { dupSearchDict = NULL; } - if (o->encoding == OBJ_ENCODING_HT && len > DICT_HT_INITIAL_SIZE) { - if (dictTryExpand(o->ptr, len) != DICT_OK) { - rdbReportCorruptRDB("OOM in dictTryExpand %llu", (unsigned long long)len); + if (o->encoding == OBJ_ENCODING_HASHTABLE) { + if (!hashtableTryExpand(o->ptr, len)) { + rdbReportCorruptRDB("OOM in hashtableTryExpand %llu", (unsigned long long)len); decrRefCount(o); return NULL; } } /* Load remaining fields and values into the hash table */ - while (o->encoding == OBJ_ENCODING_HT && len > 0) { + while (o->encoding == OBJ_ENCODING_HASHTABLE && len > 0) { len--; /* Load encoded strings */ if ((field = rdbGenericLoadStringObject(rdb, RDB_LOAD_SDS, NULL)) == NULL) { @@ -2168,11 +2168,11 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { } /* Add pair to hash table */ - ret = dictAdd((dict *)o->ptr, field, value); - if (ret == DICT_ERR) { + hashTypeEntry *entry = hashTypeCreateEntry(field, value); // TODO rainval 
avoid extra allocation of field? + sdsfree(field); + if (!hashtableAdd((hashtable *)o->ptr, entry)) { rdbReportCorruptRDB("Duplicate hash fields detected"); - sdsfree(value); - sdsfree(field); + freeHashTypeEntry(entry); decrRefCount(o); return NULL; } @@ -2317,7 +2317,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { o->encoding = OBJ_ENCODING_LISTPACK; if (hashTypeLength(o) > server.hash_max_listpack_entries || maxlen > server.hash_max_listpack_value) { - hashTypeConvert(o, OBJ_ENCODING_HT); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); } } break; @@ -2445,7 +2445,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { } if (hashTypeLength(o) > server.hash_max_listpack_entries) - hashTypeConvert(o, OBJ_ENCODING_HT); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); else o->ptr = lpShrinkToFit(o->ptr); break; @@ -2466,7 +2466,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { goto emptykey; } - if (hashTypeLength(o) > server.hash_max_listpack_entries) hashTypeConvert(o, OBJ_ENCODING_HT); + if (hashTypeLength(o) > server.hash_max_listpack_entries) hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); break; default: /* totally unreachable */ diff --git a/src/server.c b/src/server.c index 5de0a123f2..afea3d00e9 100644 --- a/src/server.c +++ b/src/server.c @@ -624,6 +624,39 @@ hashtableType subcommandSetType = {.entryGetKey = hashtableSubcommandGetKey, .keyCompare = hashtableStringKeyCaseCompare, .instant_rehashing = 1}; +/* takes ownership of value, does not take ownership of field */ +hashTypeEntry *hashTypeCreateEntry(const sds field, sds value) { + size_t field_size = sdsAllocSize(field); + void *field_data = sdsAllocPtr(field); + + size_t total_size = sizeof(hashTypeEntry) + field_size; + hashTypeEntry *hf = zmalloc(total_size); + + hf->value = value; + hf->field_offset = field - (char *)field_data; + memcpy(hf->field_data, field_data, field_size); + return hf; +} + +const void 
*hashTypeEntryGetKey(const void *entry) { + const hashTypeEntry *hf = entry; + return hf->field_data + hf->field_offset; +} + +void freeHashTypeEntry(void *entry) { + hashTypeEntry *hf = entry; + sdsfree(hf->value); + zfree(hf); +} + +/* Hash type hash table (note that small hashes are represented with listpacks) */ +hashtableType hashHashtableType = { + .hashFunction = dictSdsHash, + .entryGetKey = hashTypeEntryGetKey, + .keyCompare = hashtableSdsKeyCompare, + .entryDestructor = freeHashTypeEntry, +}; + /* Hash type hash table (note that small hashes are represented with listpacks) */ dictType hashDictType = { dictSdsHash, /* hash function */ diff --git a/src/server.h b/src/server.h index 25c6ec7f4c..131b33ba67 100644 --- a/src/server.h +++ b/src/server.h @@ -711,7 +711,7 @@ typedef struct ValkeyModuleType moduleType; * is set to one of this fields for this object. */ #define OBJ_ENCODING_RAW 0 /* Raw representation */ #define OBJ_ENCODING_INT 1 /* Encoded as integer */ -#define OBJ_ENCODING_HT 2 /* Encoded as hash table */ +#define OBJ_ENCODING_HASHTABLE 2 /* Encoded as a hashtable */ #define OBJ_ENCODING_ZIPMAP 3 /* No longer used: old hash encoding. */ #define OBJ_ENCODING_LINKEDLIST 4 /* No longer used: old list encoding. */ #define OBJ_ENCODING_ZIPLIST 5 /* No longer used: old list/hash/zset encoding. */ @@ -721,7 +721,6 @@ typedef struct ValkeyModuleType moduleType; #define OBJ_ENCODING_QUICKLIST 9 /* Encoded as linked list of listpacks */ #define OBJ_ENCODING_STREAM 10 /* Encoded as a radix tree of listpacks */ #define OBJ_ENCODING_LISTPACK 11 /* Encoded as a listpack */ -#define OBJ_ENCODING_HASHTABLE 12 /* Encoded as a hashtable */ #define LRU_BITS 24 #define LRU_CLOCK_MAX ((1 << LRU_BITS) - 1) /* Max value of obj->lru */ @@ -2519,8 +2518,8 @@ typedef struct { unsigned char *fptr, *vptr; - dictIterator di; - dictEntry *de; + hashtableIterator iter; + void *next; } hashTypeIterator; #include "stream.h" /* Stream data type header file. 
*/ @@ -2543,6 +2542,7 @@ extern hashtableType kvstoreKeysHashtableType; extern hashtableType kvstoreExpiresHashtableType; extern double R_Zero, R_PosInf, R_NegInf, R_Nan; extern dictType hashDictType; +extern hashtableType hashHashtableType; extern dictType stringSetDictType; extern dictType externalStringType; extern dictType sdsHashDictType; @@ -3232,6 +3232,15 @@ robj *setTypeDup(robj *o); #define HASH_SET_TAKE_VALUE (1 << 1) #define HASH_SET_COPY 0 +typedef struct { + sds value; + unsigned char field_offset; + char field_data[]; +} hashTypeEntry; +hashTypeEntry *hashTypeCreateEntry(sds field, sds value); +const void *hashTypeEntryGetKey(const void *entry); +void freeHashTypeEntry(void *entry); + void hashTypeConvert(robj *o, int enc); void hashTypeTryConversion(robj *subject, robj **argv, int start, int end); int hashTypeExists(robj *o, sds key); @@ -3246,7 +3255,6 @@ void hashTypeCurrentFromListpack(hashTypeIterator *hi, unsigned int *vlen, long long *vll); sds hashTypeCurrentFromHashTable(hashTypeIterator *hi, int what); -void hashTypeCurrentObject(hashTypeIterator *hi, int what, unsigned char **vstr, unsigned int *vlen, long long *vll); sds hashTypeCurrentObjectNewSds(hashTypeIterator *hi, int what); robj *hashTypeLookupWriteOrCreate(client *c, robj *key); robj *hashTypeGetValueObject(robj *o, sds field); diff --git a/src/t_hash.c b/src/t_hash.c index 1aa37968b7..dbb02a1583 100644 --- a/src/t_hash.c +++ b/src/t_hash.c @@ -48,7 +48,7 @@ void hashTypeTryConversion(robj *o, robj **argv, int start, int end) { * might over allocate memory if there are duplicates. 
*/ size_t new_fields = (end - start + 1) / 2; if (new_fields > server.hash_max_listpack_entries) { - hashTypeConvert(o, OBJ_ENCODING_HT); - dictExpand(o->ptr, new_fields); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); + hashtableExpand(o->ptr, new_fields); return; } @@ -57,12 +57,12 @@ void hashTypeTryConversion(robj *o, robj **argv, int start, int end) { if (!sdsEncodedObject(argv[i])) continue; size_t len = sdslen(argv[i]->ptr); if (len > server.hash_max_listpack_value) { - hashTypeConvert(o, OBJ_ENCODING_HT); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); return; } sum += len; } - if (!lpSafeToAdd(o->ptr, sum)) hashTypeConvert(o, OBJ_ENCODING_HT); + if (!lpSafeToAdd(o->ptr, sum)) hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); } /* Get the value from a listpack encoded hash, identified by field. @@ -95,13 +95,11 @@ int hashTypeGetFromListpack(robj *o, sds field, unsigned char **vstr, unsigned i * Returns NULL when the field cannot be found, otherwise the SDS value * is returned. */ sds hashTypeGetFromHashTable(robj *o, sds field) { - dictEntry *de; - - serverAssert(o->encoding == OBJ_ENCODING_HT); - - de = dictFind(o->ptr, field); - if (de == NULL) return NULL; - return dictGetVal(de); + serverAssert(o->encoding == OBJ_ENCODING_HASHTABLE); + void *entry; + if (!hashtableFind(o->ptr, field, &entry)) return NULL; + hashTypeEntry *hf = entry; + return hf->value; } /* Higher level function of hashTypeGet*() that returns the hash value @@ -117,9 +115,9 @@ int hashTypeGetValue(robj *o, sds field, unsigned char **vstr, unsigned int *vle if (o->encoding == OBJ_ENCODING_LISTPACK) { *vstr = NULL; if (hashTypeGetFromListpack(o, field, vstr, vlen, vll) == 0) return C_OK; - } else if (o->encoding == OBJ_ENCODING_HT) { - sds value; - if ((value = hashTypeGetFromHashTable(o, field)) != NULL) { + } else if (o->encoding == OBJ_ENCODING_HASHTABLE) { + sds value = hashTypeGetFromHashTable(o, field); + if (value != NULL) { *vstr = (unsigned char *)value; *vlen = sdslen(value); return C_OK; @@ -199,7 +197,7 @@ int
hashTypeSet(robj *o, sds field, sds value, int flags) { * hashTypeTryConversion, so this check will be a NOP. */ if (o->encoding == OBJ_ENCODING_LISTPACK) { if (sdslen(field) > server.hash_max_listpack_value || sdslen(value) > server.hash_max_listpack_value) - hashTypeConvert(o, OBJ_ENCODING_HT); + hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); } if (o->encoding == OBJ_ENCODING_LISTPACK) { @@ -228,10 +226,10 @@ int hashTypeSet(robj *o, sds field, sds value, int flags) { o->ptr = zl; /* Check if the listpack needs to be converted to a hash table */ - if (hashTypeLength(o) > server.hash_max_listpack_entries) hashTypeConvert(o, OBJ_ENCODING_HT); - } else if (o->encoding == OBJ_ENCODING_HT) { - dict *ht = o->ptr; - dictEntry *de, *existing; + if (hashTypeLength(o) > server.hash_max_listpack_entries) hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); + } else if (o->encoding == OBJ_ENCODING_HASHTABLE) { + hashtable *ht = o->ptr; + sds v; if (flags & HASH_SET_TAKE_VALUE) { v = value; @@ -239,17 +237,18 @@ int hashTypeSet(robj *o, sds field, sds value, int flags) { } else { v = sdsdup(value); } - de = dictAddRaw(ht, field, &existing); - if (de) { - dictSetVal(ht, de, v); - if (flags & HASH_SET_TAKE_FIELD) { - field = NULL; - } else { - dictSetKey(ht, de, sdsdup(field)); - } + + hashtablePosition position; + void *existing; + if (hashtableFindPositionForInsert(ht, field, &position, &existing)) { + /* does not exist yet */ + hashTypeEntry *entry = hashTypeCreateEntry(field, v); + hashtableInsertAtPosition(ht, entry, &position); } else { - sdsfree(dictGetVal(existing)); - dictSetVal(ht, existing, v); + /* exists: replace value */ + hashTypeEntry *entry = existing; + sdsfree(entry->value); + entry->value = v; update = 1; } } else { @@ -282,11 +281,9 @@ int hashTypeDelete(robj *o, sds field) { deleted = 1; } } - } else if (o->encoding == OBJ_ENCODING_HT) { - if (dictDelete((dict *)o->ptr, field) == C_OK) { - deleted = 1; - } - + } else if (o->encoding == OBJ_ENCODING_HASHTABLE) { + 
hashtable *ht = o->ptr; + deleted = hashtableDelete(ht, field); } else { serverPanic("Unknown hash encoding"); } @@ -295,16 +292,15 @@ int hashTypeDelete(robj *o, sds field) { /* Return the number of elements in a hash. */ unsigned long hashTypeLength(const robj *o) { - unsigned long length = ULONG_MAX; - - if (o->encoding == OBJ_ENCODING_LISTPACK) { - length = lpLength(o->ptr) / 2; - } else if (o->encoding == OBJ_ENCODING_HT) { - length = dictSize((const dict *)o->ptr); - } else { + switch (o->encoding) { + case OBJ_ENCODING_LISTPACK: + return lpLength(o->ptr) / 2; + case OBJ_ENCODING_HASHTABLE: + return hashtableSize((const hashtable *)o->ptr); + default: serverPanic("Unknown hash encoding"); + return ULONG_MAX; } - return length; } void hashTypeInitIterator(robj *subject, hashTypeIterator *hi) { @@ -314,15 +310,15 @@ void hashTypeInitIterator(robj *subject, hashTypeIterator *hi) { if (hi->encoding == OBJ_ENCODING_LISTPACK) { hi->fptr = NULL; hi->vptr = NULL; - } else if (hi->encoding == OBJ_ENCODING_HT) { - dictInitIterator(&hi->di, subject->ptr); + } else if (hi->encoding == OBJ_ENCODING_HASHTABLE) { + hashtableInitIterator(&hi->iter, subject->ptr); } else { serverPanic("Unknown hash encoding"); } } void hashTypeResetIterator(hashTypeIterator *hi) { - if (hi->encoding == OBJ_ENCODING_HT) dictResetIterator(&hi->di); + if (hi->encoding == OBJ_ENCODING_HASHTABLE) hashtableResetIterator(&hi->iter); } /* Move to the next entry in the hash. 
Return C_OK when the next entry @@ -354,8 +350,8 @@ int hashTypeNext(hashTypeIterator *hi) { /* fptr, vptr now point to the first or next pair */ hi->fptr = fptr; hi->vptr = vptr; - } else if (hi->encoding == OBJ_ENCODING_HT) { - if ((hi->de = dictNext(&hi->di)) == NULL) return C_ERR; + } else if (hi->encoding == OBJ_ENCODING_HASHTABLE) { + if (!hashtableNext(&hi->iter, &hi->next)) return C_ERR; } else { serverPanic("Unknown hash encoding"); } @@ -382,12 +378,14 @@ void hashTypeCurrentFromListpack(hashTypeIterator *hi, * encoded as a hash table. Prototype is similar to * `hashTypeGetFromHashTable`. */ sds hashTypeCurrentFromHashTable(hashTypeIterator *hi, int what) { - serverAssert(hi->encoding == OBJ_ENCODING_HT); + serverAssert(hi->encoding == OBJ_ENCODING_HASHTABLE); if (what & OBJ_HASH_KEY) { - return dictGetKey(hi->de); + const void *key = hashTypeEntryGetKey(hi->next); + return (sds)key; } else { - return dictGetVal(hi->de); + hashTypeEntry *field = hi->next; + return field->value; } } @@ -401,11 +399,11 @@ sds hashTypeCurrentFromHashTable(hashTypeIterator *hi, int what) { * If *vll is populated *vstr is set to NULL, so the caller * can always check the function return by checking the return value * type checking if vstr == NULL. */ -void hashTypeCurrentObject(hashTypeIterator *hi, int what, unsigned char **vstr, unsigned int *vlen, long long *vll) { +static void hashTypeCurrentObject(hashTypeIterator *hi, int what, unsigned char **vstr, unsigned int *vlen, long long *vll) { if (hi->encoding == OBJ_ENCODING_LISTPACK) { *vstr = NULL; hashTypeCurrentFromListpack(hi, what, vstr, vlen, vll); - } else if (hi->encoding == OBJ_ENCODING_HT) { + } else if (hi->encoding == OBJ_ENCODING_HASHTABLE) { sds ele = hashTypeCurrentFromHashTable(hi, what); *vstr = (unsigned char *)ele; *vlen = sdslen(ele); @@ -444,26 +442,22 @@ void hashTypeConvertListpack(robj *o, int enc) { if (enc == OBJ_ENCODING_LISTPACK) { /* Nothing to do... 
*/ - } else if (enc == OBJ_ENCODING_HT) { + } else if (enc == OBJ_ENCODING_HASHTABLE) { hashTypeIterator hi; - dict *dict; - int ret; - hashTypeInitIterator(o, &hi); - dict = dictCreate(&hashDictType); + hashtable *ht = hashtableCreate(&hashHashtableType); - /* Presize the dict to avoid rehashing */ - dictExpand(dict, hashTypeLength(o)); + /* Presize the hashtable to avoid rehashing */ + hashtableExpand(ht, hashTypeLength(o)); + hashTypeInitIterator(o, &hi); while (hashTypeNext(&hi) != C_ERR) { - sds key, value; - - key = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_KEY); - value = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_VALUE); - ret = dictAdd(dict, key, value); - if (ret != DICT_OK) { - sdsfree(key); - sdsfree(value); /* Needed for gcc ASAN */ + sds key = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_KEY); // TODO rainval don't copy twice - here and creation of entry + sds value = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_VALUE); + hashTypeEntry *field = hashTypeCreateEntry(key, value); + sdsfree(key); + if (!hashtableAdd(ht, field)) { + freeHashTypeEntry(field); hashTypeResetIterator(&hi); /* Needed for gcc ASAN */ serverLogHexDump(LL_WARNING, "listpack with dup elements dump", o->ptr, lpBytes(o->ptr)); serverPanic("Listpack corruption detected"); @@ -471,8 +465,8 @@ void hashTypeConvertListpack(robj *o, int enc) { } hashTypeResetIterator(&hi); zfree(o->ptr); - o->encoding = OBJ_ENCODING_HT; - o->ptr = dict; + o->encoding = OBJ_ENCODING_HASHTABLE; + o->ptr = ht; } else { serverPanic("Unknown hash encoding"); } @@ -481,7 +475,7 @@ void hashTypeConvertListpack(robj *o, int enc) { void hashTypeConvert(robj *o, int enc) { if (o->encoding == OBJ_ENCODING_LISTPACK) { hashTypeConvertListpack(o, enc); - } else if (o->encoding == OBJ_ENCODING_HT) { + } else if (o->encoding == OBJ_ENCODING_HASHTABLE) { serverPanic("Not implemented"); } else { serverPanic("Unknown hash encoding"); @@ -506,27 +500,24 @@ robj *hashTypeDup(robj *o) { memcpy(new_zl, zl, sz); hobj = 
createObject(OBJ_HASH, new_zl); hobj->encoding = OBJ_ENCODING_LISTPACK; - } else if (o->encoding == OBJ_ENCODING_HT) { - dict *d = dictCreate(&hashDictType); - dictExpand(d, dictSize((const dict *)o->ptr)); + } else if (o->encoding == OBJ_ENCODING_HASHTABLE) { + hashtable *ht = hashtableCreate(&hashHashtableType); + hashtableExpand(ht, hashtableSize((const hashtable *)o->ptr)); hashTypeInitIterator(o, &hi); while (hashTypeNext(&hi) != C_ERR) { - sds field, value; - sds newfield, newvalue; /* Extract a field-value pair from an original hash object.*/ - field = hashTypeCurrentFromHashTable(&hi, OBJ_HASH_KEY); - value = hashTypeCurrentFromHashTable(&hi, OBJ_HASH_VALUE); - newfield = sdsdup(field); - newvalue = sdsdup(value); + sds field = hashTypeCurrentFromHashTable(&hi, OBJ_HASH_KEY); + sds value = hashTypeCurrentFromHashTable(&hi, OBJ_HASH_VALUE); /* Add a field-value pair to a new hash object. */ - dictAdd(d, newfield, newvalue); + hashTypeEntry *entry = hashTypeCreateEntry(field, sdsdup(value)); + hashtableAdd(ht, entry); } hashTypeResetIterator(&hi); - hobj = createObject(OBJ_HASH, d); - hobj->encoding = OBJ_ENCODING_HT; + hobj = createObject(OBJ_HASH, ht); + hobj->encoding = OBJ_ENCODING_HASHTABLE; } else { serverPanic("Unknown hash encoding"); } @@ -550,16 +541,17 @@ void hashReplyFromListpackEntry(client *c, listpackEntry *e) { * 'key' and 'val' will be set to hold the element. * The memory in them is not to be freed or modified by the caller. * 'val' can be NULL in which case it's not extracted. 
*/ -void hashTypeRandomElement(robj *hashobj, unsigned long hashsize, listpackEntry *key, listpackEntry *val) { - if (hashobj->encoding == OBJ_ENCODING_HT) { - dictEntry *de = dictGetFairRandomKey(hashobj->ptr); - sds s = dictGetKey(de); +static void hashTypeRandomElement(robj *hashobj, unsigned long hashsize, listpackEntry *key, listpackEntry *val) { + if (hashobj->encoding == OBJ_ENCODING_HASHTABLE) { + void *entry; + hashtableFairRandomEntry(hashobj->ptr, &entry); + const unsigned char *s = hashTypeEntryGetKey(entry); key->sval = (unsigned char *)s; - key->slen = sdslen(s); + key->slen = sdslen((sds)s); if (val) { - sds s = dictGetVal(de); - val->sval = (unsigned char *)s; - val->slen = sdslen(s); + hashTypeEntry *field = entry; + val->sval = (unsigned char *)field->value; + val->slen = sdslen(field->value); } } else if (hashobj->encoding == OBJ_ENCODING_LISTPACK) { lpRandomPair(hashobj->ptr, hashsize, key, val); @@ -799,7 +791,7 @@ static void addHashIteratorCursorToReply(writePreparedClient *wpc, hashTypeItera addWritePreparedReplyBulkCBuffer(wpc, vstr, vlen); else addWritePreparedReplyBulkLongLong(wpc, vll); - } else if (hi->encoding == OBJ_ENCODING_HT) { + } else if (hi->encoding == OBJ_ENCODING_HASHTABLE) { sds value = hashTypeCurrentFromHashTable(hi, what); addWritePreparedReplyBulkCBuffer(wpc, value, sdslen(value)); } else { @@ -933,12 +925,13 @@ void hrandfieldWithCountCommand(client *c, long l, int withvalues) { addWritePreparedReplyArrayLen(wpc, count * 2); else addWritePreparedReplyArrayLen(wpc, count); - if (hash->encoding == OBJ_ENCODING_HT) { - sds key, value; + if (hash->encoding == OBJ_ENCODING_HASHTABLE) { while (count--) { - dictEntry *de = dictGetFairRandomKey(hash->ptr); - key = dictGetKey(de); - value = dictGetVal(de); + void *entry; + hashtableFairRandomEntry(hash->ptr, &entry); + hashTypeEntry *field = entry; + sds key = (sds)hashTypeEntryGetKey(entry); + sds value = field->value; if (withvalues && c->resp > 2) 
addWritePreparedReplyArrayLen(wpc, 2); addWritePreparedReplyBulkCBuffer(wpc, key, sdslen(key)); if (withvalues) addWritePreparedReplyBulkCBuffer(wpc, value, sdslen(value)); From bc7d761f206125800e90ddab75aa06e9bd170df5 Mon Sep 17 00:00:00 2001 From: Rain Valentine Date: Fri, 3 Jan 2025 22:09:29 +0000 Subject: [PATCH 2/3] make hashTypeEntry opaque and consistent naming: hash keys contain field/value pairs Signed-off-by: Rain Valentine --- src/aof.c | 4 +- src/db.c | 5 +- src/debug.c | 2 +- src/defrag.c | 13 ++-- src/module.c | 5 +- src/object.c | 10 +-- src/rdb.c | 6 +- src/server.c | 33 +++------ src/server.h | 16 ++--- src/t_hash.c | 194 ++++++++++++++++++++++++++++++++------------------- 10 files changed, 157 insertions(+), 131 deletions(-) diff --git a/src/aof.c b/src/aof.c index b02661c5a3..024cdb2771 100644 --- a/src/aof.c +++ b/src/aof.c @@ -1922,7 +1922,7 @@ int rewriteSortedSetObject(rio *r, robj *key, robj *o) { /* Write either the key or the value of the currently selected item of a hash. * The 'hi' argument passes a valid hash iterator. * The 'what' filed specifies if to write a key or a value and can be - * either OBJ_HASH_KEY or OBJ_HASH_VALUE. + * either OBJ_HASH_FIELD or OBJ_HASH_VALUE. * * The function returns 0 on error, non-zero on success. 
*/ static int rioWriteHashIteratorCursor(rio *r, hashTypeIterator *hi, int what) { @@ -1963,7 +1963,7 @@ int rewriteHashObject(rio *r, robj *key, robj *o) { } } - if (!rioWriteHashIteratorCursor(r, &hi, OBJ_HASH_KEY) || !rioWriteHashIteratorCursor(r, &hi, OBJ_HASH_VALUE)) { + if (!rioWriteHashIteratorCursor(r, &hi, OBJ_HASH_FIELD) || !rioWriteHashIteratorCursor(r, &hi, OBJ_HASH_VALUE)) { hashTypeResetIterator(&hi); return 0; } diff --git a/src/db.c b/src/db.c index 4f78f3a157..94074bf668 100644 --- a/src/db.c +++ b/src/db.c @@ -1000,10 +1000,9 @@ void hashtableScanCallback(void *privdata, void *entry) { key = node->ele; /* zset data is copied after filtering by key */ } else if (o->type == OBJ_HASH) { - key = (sds)hashTypeEntryGetKey(entry); + key = hashTypeEntryGetField(entry); if (!data->only_keys) { - hashTypeEntry *hash_entry = entry; - val = hash_entry->value; + val = hashTypeEntryGetValue(entry); } } else { serverPanic("Type not handled in hashtable SCAN callback."); diff --git a/src/debug.c b/src/debug.c index ebaee54f92..915e0c264d 100644 --- a/src/debug.c +++ b/src/debug.c @@ -231,7 +231,7 @@ void xorObjectDigest(serverDb *db, robj *keyobj, unsigned char *digest, robj *o) sds sdsele; memset(eledigest, 0, 20); - sdsele = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_KEY); + sdsele = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_FIELD); mixDigest(eledigest, sdsele, sdslen(sdsele)); sdsfree(sdsele); sdsele = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_VALUE); diff --git a/src/defrag.c b/src/defrag.c index bf9632264b..4d62360201 100644 --- a/src/defrag.c +++ b/src/defrag.c @@ -480,23 +480,18 @@ static void scanLaterSet(robj *ob, unsigned long *cursor) { *cursor = hashtableScanDefrag(ht, *cursor, activeDefragSdsHashtableCallback, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); } -static void activeDefragHashField(void *privdata, void *element_ref) { +static void activeDefragHashTypeEntry(void *privdata, void *element_ref) { UNUSED(privdata); hashTypeEntry **entry_ref 
= (hashTypeEntry **)element_ref; - /* defragment field */ - hashTypeEntry *new_entry = activeDefragAlloc(*entry_ref); + hashTypeEntry *new_entry = hashTypeEntryDefrag(*entry_ref, activeDefragAlloc, activeDefragSds); if (new_entry) *entry_ref = new_entry; - - /* defragment value */ - sds new_value = activeDefragSds((*entry_ref)->value); - if (new_value) (*entry_ref)->value = new_value; } static void scanLaterHash(robj *ob, unsigned long *cursor) { if (ob->type != OBJ_HASH || ob->encoding != OBJ_ENCODING_HASHTABLE) return; hashtable *ht = ob->ptr; - *cursor = hashtableScanDefrag(ht, *cursor, activeDefragHashField, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); + *cursor = hashtableScanDefrag(ht, *cursor, activeDefragHashTypeEntry, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); } static void defragQuicklist(robj *ob) { @@ -541,7 +536,7 @@ static void defragHash(robj *ob) { } else { unsigned long cursor = 0; do { - cursor = hashtableScanDefrag(ht, cursor, activeDefragHashField, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); + cursor = hashtableScanDefrag(ht, cursor, activeDefragHashTypeEntry, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); } while (cursor != 0); } /* defrag the hashtable struct and tables */ diff --git a/src/module.c b/src/module.c index 809fcfb8e2..6d508e2965 100644 --- a/src/module.c +++ b/src/module.c @@ -11104,9 +11104,8 @@ static void moduleScanKeyHashtableCallback(void *privdata, void *entry) { key = node->ele; value = createStringObjectFromLongDouble(node->score, 0); } else if (o->type == OBJ_HASH) { - key = (sds)hashTypeEntryGetKey(entry); - hashTypeEntry *hash_entry = entry; - sds val = hash_entry->value; + key = hashTypeEntryGetField(entry); + sds val = hashTypeEntryGetValue(entry); value = createStringObject(val, sdslen(val)); } else { serverPanic("unexpected object type"); diff --git a/src/object.c b/src/object.c index 8c6bfa6cd8..b8200dd815 100644 --- a/src/object.c +++ b/src/object.c @@ -685,11 +685,7 @@ void 
dismissHashObject(robj *o, size_t size_hint) { hashtableInitIterator(&iter, ht); void *next; while (hashtableNext(&iter, &next)) { - /* Only dismiss values memory since the field size - * usually is small. */ - hashTypeEntry *entry = next; - UNUSED(entry); - dismissSds(entry->value); + dismissHashTypeEntry(next); } hashtableResetIterator(&iter); } @@ -1206,9 +1202,7 @@ size_t objectComputeSize(robj *key, robj *o, size_t sample_size, int dbid) { asize = sizeof(*o) + hashtableMemUsage(ht); while (hashtableNext(&iter, &next) && samples < sample_size) { - elesize += zmalloc_usable_size(next); - hashTypeEntry *entry = next; - elesize += sdsAllocSize(entry->value); + elesize += hashTypeEntryAllocSize(next); samples++; } hashtableResetIterator(&iter); diff --git a/src/rdb.c b/src/rdb.c index 59c97732e2..0bb5d7d45d 100644 --- a/src/rdb.c +++ b/src/rdb.c @@ -962,8 +962,8 @@ ssize_t rdbSaveObject(rio *rdb, robj *o, robj *key, int dbid) { hashtableInitIterator(&iter, ht); void *next; while (hashtableNext(&iter, &next)) { - sds field = (sds)hashTypeEntryGetKey(next); - sds value = ((hashTypeEntry *)next)->value; + sds field = hashTypeEntryGetField(next); + sds value = hashTypeEntryGetValue(next); if ((n = rdbSaveRawString(rdb, (unsigned char *)field, sdslen(field))) == -1) { hashtableResetIterator(&iter); @@ -2168,7 +2168,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { } /* Add pair to hash table */ - hashTypeEntry *entry = hashTypeCreateEntry(field, value); // TODO rainval avoid extra allocation of field? 
+ hashTypeEntry *entry = hashTypeCreateEntry(field, value); sdsfree(field); if (!hashtableAdd((hashtable *)o->ptr, entry)) { rdbReportCorruptRDB("Duplicate hash fields detected"); diff --git a/src/server.c b/src/server.c index afea3d00e9..b489409f92 100644 --- a/src/server.c +++ b/src/server.c @@ -624,37 +624,22 @@ hashtableType subcommandSetType = {.entryGetKey = hashtableSubcommandGetKey, .keyCompare = hashtableStringKeyCaseCompare, .instant_rehashing = 1}; -/* takes ownership of value, does not take ownership of field */ -hashTypeEntry *hashTypeCreateEntry(const sds field, sds value) { - size_t field_size = sdsAllocSize(field); - void *field_data = sdsAllocPtr(field); - - size_t total_size = sizeof(hashTypeEntry) + field_size; - hashTypeEntry *hf = zmalloc(total_size); - - hf->value = value; - hf->field_offset = field - (char *)field_data; - memcpy(hf->field_data, field_data, field_size); - return hf; -} - -const void *hashTypeEntryGetKey(const void *entry) { - const hashTypeEntry *hf = entry; - return hf->field_data + hf->field_offset; +/* Hash type hash table (note that small hashes are represented with listpacks) */ +const void *hashHashtableTypeGetKey(const void *entry) { + const hashTypeEntry *hash_entry = entry; + return (const void *)hashTypeEntryGetField(hash_entry); } -void freeHashTypeEntry(void *entry) { - hashTypeEntry *hf = entry; - sdsfree(hf->value); - zfree(hf); +void hashHashtableTypeDestructor(void *entry) { + hashTypeEntry *hash_entry = entry; + freeHashTypeEntry(hash_entry); } -/* Hash type hash table (note that small hashes are represented with listpacks) */ hashtableType hashHashtableType = { .hashFunction = dictSdsHash, - .entryGetKey = hashTypeEntryGetKey, + .entryGetKey = hashHashtableTypeGetKey, .keyCompare = hashtableSdsKeyCompare, - .entryDestructor = freeHashTypeEntry, + .entryDestructor = hashHashtableTypeDestructor, }; /* Hash type hash table (note that small hashes are represented with listpacks) */ diff --git a/src/server.h 
b/src/server.h index 131b33ba67..0725c6a880 100644 --- a/src/server.h +++ b/src/server.h @@ -2524,7 +2524,7 @@ typedef struct { #include "stream.h" /* Stream data type header file. */ -#define OBJ_HASH_KEY 1 +#define OBJ_HASH_FIELD 1 #define OBJ_HASH_VALUE 2 /*----------------------------------------------------------------------------- @@ -3232,14 +3232,14 @@ robj *setTypeDup(robj *o); #define HASH_SET_TAKE_VALUE (1 << 1) #define HASH_SET_COPY 0 -typedef struct { - sds value; - unsigned char field_offset; - char field_data[]; -} hashTypeEntry; +typedef struct hashTypeEntry hashTypeEntry; hashTypeEntry *hashTypeCreateEntry(sds field, sds value); -const void *hashTypeEntryGetKey(const void *entry); -void freeHashTypeEntry(void *entry); +sds hashTypeEntryGetField(const hashTypeEntry *entry); +sds hashTypeEntryGetValue(const hashTypeEntry *entry); +size_t hashTypeEntryAllocSize(hashTypeEntry *entry); +hashTypeEntry *hashTypeEntryDefrag(hashTypeEntry *entry, void *(*defragfn)(void *), sds (*sdsdefragfn)(sds)); +void dismissHashTypeEntry(hashTypeEntry *entry); +void freeHashTypeEntry(hashTypeEntry *entry); void hashTypeConvert(robj *o, int enc); void hashTypeTryConversion(robj *subject, robj **argv, int start, int end); diff --git a/src/t_hash.c b/src/t_hash.c index dbb02a1583..3af41ddddb 100644 --- a/src/t_hash.c +++ b/src/t_hash.c @@ -30,6 +30,65 @@ #include "server.h" #include +struct hashTypeEntry { + sds value; + unsigned char field_offset; + unsigned char field_data[]; +}; + +/* takes ownership of value, does not take ownership of field */ +hashTypeEntry *hashTypeCreateEntry(const sds field, sds value) { + size_t field_size = sdscopytobuffer(NULL, 0, field, NULL); + + size_t total_size = sizeof(hashTypeEntry) + field_size; + hashTypeEntry *entry = zmalloc(total_size); + + entry->value = value; + sdscopytobuffer(entry->field_data, field_size, field, &entry->field_offset); + return entry; +} + +sds hashTypeEntryGetField(const hashTypeEntry *entry) { + const unsigned 
char *field = entry->field_data + entry->field_offset; + return (sds)field; +} + +sds hashTypeEntryGetValue(const hashTypeEntry *entry) { + return entry->value; +} + +/* frees previous value, takes ownership of new value */ +static void hashTypeEntryReplaceValue(hashTypeEntry *entry, sds value) { + sdsfree(entry->value); + entry->value = value; +} + +size_t hashTypeEntryAllocSize(hashTypeEntry *entry) { + size_t size = zmalloc_usable_size(entry); + size += sdsAllocSize(entry->value); + return size; +} + +hashTypeEntry *hashTypeEntryDefrag(hashTypeEntry *entry, void *(*defragfn)(void *), sds (*sdsdefragfn)(sds)) { + hashTypeEntry *new_entry = defragfn(entry); + if (new_entry) entry = new_entry; + + sds new_value = sdsdefragfn(entry->value); + if (new_value) entry->value = new_value; + + return entry; +} + +void dismissHashTypeEntry(hashTypeEntry *entry) { + /* Only dismiss values memory since the field size usually is small. */ + dismissSds(entry->value); +} + +void freeHashTypeEntry(hashTypeEntry *entry) { + sdsfree(entry->value); + zfree(entry); +} + /*----------------------------------------------------------------------------- * Hash type API *----------------------------------------------------------------------------*/ @@ -49,7 +108,7 @@ void hashTypeTryConversion(robj *o, robj **argv, int start, int end) { size_t new_fields = (end - start + 1) / 2; if (new_fields > server.hash_max_listpack_entries) { hashTypeConvert(o, OBJ_ENCODING_HASHTABLE); - dictExpand(o->ptr, new_fields); + hashtableExpand(o->ptr, new_fields); return; } @@ -96,10 +155,9 @@ int hashTypeGetFromListpack(robj *o, sds field, unsigned char **vstr, unsigned i * is returned. 
*/ sds hashTypeGetFromHashTable(robj *o, sds field) { serverAssert(o->encoding == OBJ_ENCODING_HASHTABLE); - void *entry; - if (!hashtableFind(o->ptr, field, &entry)) return NULL; - hashTypeEntry *hf = entry; - return hf->value; + void *found_element; + if (!hashtableFind(o->ptr, field, &found_element)) return NULL; + return hashTypeEntryGetValue(found_element); } /* Higher level function of hashTypeGet*() that returns the hash value @@ -171,7 +229,7 @@ int hashTypeExists(robj *o, sds field) { /* Add a new field, overwrite the old with the new value if it already exists. * Return 0 on insert and 1 on update. * - * By default, the key and value SDS strings are copied if needed, so the + * By default, the field and value SDS strings are copied if needed, so the * caller retains ownership of the strings passed. However this behavior * can be effected by passing appropriate flags (possibly bitwise OR-ed): * @@ -246,9 +304,7 @@ int hashTypeSet(robj *o, sds field, sds value, int flags) { hashtableInsertAtPosition(ht, entry, &position); } else { /* exists: replace value */ - hashTypeEntry *entry = existing; - sdsfree(entry->value); - entry->value = v; + hashTypeEntryReplaceValue(existing, v); update = 1; } } else { @@ -275,7 +331,7 @@ int hashTypeDelete(robj *o, sds field) { if (fptr != NULL) { fptr = lpFind(zl, fptr, (unsigned char *)field, sdslen(field), 1); if (fptr != NULL) { - /* Delete both of the key and the value. */ + /* Delete both field and value. 
*/ zl = lpDeleteRangeWithEntry(zl, &fptr, 2); o->ptr = zl; deleted = 1; @@ -367,7 +423,7 @@ void hashTypeCurrentFromListpack(hashTypeIterator *hi, long long *vll) { serverAssert(hi->encoding == OBJ_ENCODING_LISTPACK); - if (what & OBJ_HASH_KEY) { + if (what & OBJ_HASH_FIELD) { *vstr = lpGetValue(hi->fptr, vlen, vll); } else { *vstr = lpGetValue(hi->vptr, vlen, vll); @@ -380,12 +436,10 @@ void hashTypeCurrentFromListpack(hashTypeIterator *hi, sds hashTypeCurrentFromHashTable(hashTypeIterator *hi, int what) { serverAssert(hi->encoding == OBJ_ENCODING_HASHTABLE); - if (what & OBJ_HASH_KEY) { - const void *key = hashTypeEntryGetKey(hi->next); - return (sds)key; + if (what & OBJ_HASH_FIELD) { + return hashTypeEntryGetField(hi->next); } else { - hashTypeEntry *field = hi->next; - return field->value; + return hashTypeEntryGetValue(hi->next); } } @@ -412,7 +466,7 @@ static void hashTypeCurrentObject(hashTypeIterator *hi, int what, unsigned char } } -/* Return the key or value at the current iterator position as a new +/* Return the field or value at the current iterator position as a new * SDS string. 
*/ sds hashTypeCurrentObjectNewSds(hashTypeIterator *hi, int what) { unsigned char *vstr; @@ -452,12 +506,12 @@ void hashTypeConvertListpack(robj *o, int enc) { hashTypeInitIterator(o, &hi); while (hashTypeNext(&hi) != C_ERR) { - sds key = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_KEY); // TODO rainval don't copy twice - here and creation of entry + sds field = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_FIELD); sds value = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_VALUE); - hashTypeEntry *field = hashTypeCreateEntry(key, value); - sdsfree(key); - if (!hashtableAdd(ht, field)) { - freeHashTypeEntry(field); + hashTypeEntry *entry = hashTypeCreateEntry(field, value); + sdsfree(field); + if (!hashtableAdd(ht, entry)) { + freeHashTypeEntry(entry); hashTypeResetIterator(&hi); /* Needed for gcc ASAN */ serverLogHexDump(LL_WARNING, "listpack with dup elements dump", o->ptr, lpBytes(o->ptr)); serverPanic("Listpack corruption detected"); @@ -507,7 +561,7 @@ robj *hashTypeDup(robj *o) { hashTypeInitIterator(o, &hi); while (hashTypeNext(&hi) != C_ERR) { /* Extract a field-value pair from an original hash object.*/ - sds field = hashTypeCurrentFromHashTable(&hi, OBJ_HASH_KEY); + sds field = hashTypeCurrentFromHashTable(&hi, OBJ_HASH_FIELD); sds value = hashTypeCurrentFromHashTable(&hi, OBJ_HASH_VALUE); /* Add a field-value pair to a new hash object. */ @@ -538,23 +592,24 @@ void hashReplyFromListpackEntry(client *c, listpackEntry *e) { } /* Return random element from a non empty hash. - * 'key' and 'val' will be set to hold the element. + * 'field' and 'val' will be set to hold the element. * The memory in them is not to be freed or modified by the caller. * 'val' can be NULL in which case it's not extracted. 
*/ -static void hashTypeRandomElement(robj *hashobj, unsigned long hashsize, listpackEntry *key, listpackEntry *val) { +static void hashTypeRandomElement(robj *hashobj, unsigned long hashsize, listpackEntry *field, listpackEntry *val) { if (hashobj->encoding == OBJ_ENCODING_HASHTABLE) { void *entry; hashtableFairRandomEntry(hashobj->ptr, &entry); - const unsigned char *s = hashTypeEntryGetKey(entry); - key->sval = (unsigned char *)s; - key->slen = sdslen((sds)s); + sds sds_field = hashTypeEntryGetField(entry); + field->sval = (unsigned char *)sds_field; + field->slen = sdslen(sds_field); if (val) { - hashTypeEntry *field = entry; - val->sval = (unsigned char *)field->value; - val->slen = sdslen(field->value); + hashTypeEntry *hash_entry = entry; + sds sds_val = hashTypeEntryGetValue(hash_entry); + val->sval = (unsigned char *)sds_val; + val->slen = sdslen(sds_val); } } else if (hashobj->encoding == OBJ_ENCODING_LISTPACK) { - lpRandomPair(hashobj->ptr, hashsize, key, val); + lpRandomPair(hashobj->ptr, hashsize, field, val); } else { serverPanic("Unknown hash encoding"); } @@ -804,15 +859,15 @@ void genericHgetallCommand(client *c, int flags) { hashTypeIterator hi; int length, count = 0; - robj *emptyResp = (flags & OBJ_HASH_KEY && flags & OBJ_HASH_VALUE) ? shared.emptymap[c->resp] : shared.emptyarray; + robj *emptyResp = (flags & OBJ_HASH_FIELD && flags & OBJ_HASH_VALUE) ? shared.emptymap[c->resp] : shared.emptyarray; if ((o = lookupKeyReadOrReply(c, c->argv[1], emptyResp)) == NULL || checkType(c, o, OBJ_HASH)) return; writePreparedClient *wpc = prepareClientForFutureWrites(c); if (!wpc) return; - /* We return a map if the user requested keys and values, like in the + /* We return a map if the user requested fields and values, like in the * HGETALL case. Otherwise to use a flat array makes more sense. 
*/ length = hashTypeLength(o); - if (flags & OBJ_HASH_KEY && flags & OBJ_HASH_VALUE) { + if (flags & OBJ_HASH_FIELD && flags & OBJ_HASH_VALUE) { addWritePreparedReplyMapLen(wpc, length); } else { addWritePreparedReplyArrayLen(wpc, length); @@ -820,8 +875,8 @@ void genericHgetallCommand(client *c, int flags) { hashTypeInitIterator(o, &hi); while (hashTypeNext(&hi) != C_ERR) { - if (flags & OBJ_HASH_KEY) { - addHashIteratorCursorToReply(wpc, &hi, OBJ_HASH_KEY); + if (flags & OBJ_HASH_FIELD) { + addHashIteratorCursorToReply(wpc, &hi, OBJ_HASH_FIELD); count++; } if (flags & OBJ_HASH_VALUE) { @@ -833,12 +888,12 @@ void genericHgetallCommand(client *c, int flags) { hashTypeResetIterator(&hi); /* Make sure we returned the right number of elements. */ - if (flags & OBJ_HASH_KEY && flags & OBJ_HASH_VALUE) count /= 2; + if (flags & OBJ_HASH_FIELD && flags & OBJ_HASH_VALUE) count /= 2; serverAssert(count == length); } void hkeysCommand(client *c) { - genericHgetallCommand(c, OBJ_HASH_KEY); + genericHgetallCommand(c, OBJ_HASH_FIELD); } void hvalsCommand(client *c) { @@ -846,7 +901,7 @@ void hvalsCommand(client *c) { } void hgetallCommand(client *c) { - genericHgetallCommand(c, OBJ_HASH_KEY | OBJ_HASH_VALUE); + genericHgetallCommand(c, OBJ_HASH_FIELD | OBJ_HASH_VALUE); } void hexistsCommand(client *c) { @@ -865,14 +920,14 @@ void hscanCommand(client *c) { scanGenericCommand(c, o, cursor); } -static void hrandfieldReplyWithListpack(writePreparedClient *wpc, unsigned int count, listpackEntry *keys, listpackEntry *vals) { +static void hrandfieldReplyWithListpack(writePreparedClient *wpc, unsigned int count, listpackEntry *fields, listpackEntry *vals) { client *c = (client *)wpc; for (unsigned long i = 0; i < count; i++) { if (vals && c->resp > 2) addWritePreparedReplyArrayLen(wpc, 2); - if (keys[i].sval) - addWritePreparedReplyBulkCBuffer(wpc, keys[i].sval, keys[i].slen); + if (fields[i].sval) + addWritePreparedReplyBulkCBuffer(wpc, fields[i].sval, fields[i].slen); else - 
addWritePreparedReplyBulkLongLong(wpc, keys[i].lval); + addWritePreparedReplyBulkLongLong(wpc, fields[i].lval); if (vals) { if (vals[i].sval) addWritePreparedReplyBulkCBuffer(wpc, vals[i].sval, vals[i].slen); @@ -929,29 +984,28 @@ void hrandfieldWithCountCommand(client *c, long l, int withvalues) { while (count--) { void *entry; hashtableFairRandomEntry(hash->ptr, &entry); - hashTypeEntry *field = entry; - sds key = (sds)hashTypeEntryGetKey(entry); - sds value = field->value; + sds field = hashTypeEntryGetField(entry); + sds value = hashTypeEntryGetValue(entry); if (withvalues && c->resp > 2) addWritePreparedReplyArrayLen(wpc, 2); - addWritePreparedReplyBulkCBuffer(wpc, key, sdslen(key)); + addWritePreparedReplyBulkCBuffer(wpc, field, sdslen(field)); if (withvalues) addWritePreparedReplyBulkCBuffer(wpc, value, sdslen(value)); if (c->flag.close_asap) break; } } else if (hash->encoding == OBJ_ENCODING_LISTPACK) { - listpackEntry *keys, *vals = NULL; + listpackEntry *fields, *vals = NULL; unsigned long limit, sample_count; limit = count > HRANDFIELD_RANDOM_SAMPLE_LIMIT ? HRANDFIELD_RANDOM_SAMPLE_LIMIT : count; - keys = zmalloc(sizeof(listpackEntry) * limit); + fields = zmalloc(sizeof(listpackEntry) * limit); if (withvalues) vals = zmalloc(sizeof(listpackEntry) * limit); while (count) { sample_count = count > limit ? 
limit : count; count -= sample_count; - lpRandomPairs(hash->ptr, sample_count, keys, vals); - hrandfieldReplyWithListpack(wpc, sample_count, keys, vals); + lpRandomPairs(hash->ptr, sample_count, fields, vals); + hrandfieldReplyWithListpack(wpc, sample_count, fields, vals); if (c->flag.close_asap) break; } - zfree(keys); + zfree(fields); zfree(vals); } return; @@ -972,7 +1026,7 @@ void hrandfieldWithCountCommand(client *c, long l, int withvalues) { hashTypeInitIterator(hash, &hi); while (hashTypeNext(&hi) != C_ERR) { if (withvalues && c->resp > 2) addWritePreparedReplyArrayLen(wpc, 2); - addHashIteratorCursorToReply(wpc, &hi, OBJ_HASH_KEY); + addHashIteratorCursorToReply(wpc, &hi, OBJ_HASH_FIELD); if (withvalues) addHashIteratorCursorToReply(wpc, &hi, OBJ_HASH_VALUE); } hashTypeResetIterator(&hi); @@ -988,12 +1042,12 @@ void hrandfieldWithCountCommand(client *c, long l, int withvalues) { * And it is inefficient to repeatedly pick one random element from a * listpack in CASE 4. So we use this instead. */ if (hash->encoding == OBJ_ENCODING_LISTPACK) { - listpackEntry *keys, *vals = NULL; - keys = zmalloc(sizeof(listpackEntry) * count); + listpackEntry *fields, *vals = NULL; + fields = zmalloc(sizeof(listpackEntry) * count); if (withvalues) vals = zmalloc(sizeof(listpackEntry) * count); - serverAssert(lpRandomPairsUnique(hash->ptr, count, keys, vals) == count); - hrandfieldReplyWithListpack(wpc, count, keys, vals); - zfree(keys); + serverAssert(lpRandomPairsUnique(hash->ptr, count, fields, vals) == count); + hrandfieldReplyWithListpack(wpc, count, fields, vals); + zfree(fields); zfree(vals); return; } @@ -1017,11 +1071,11 @@ void hrandfieldWithCountCommand(client *c, long l, int withvalues) { /* Add all the elements into the temporary dictionary. 
*/ while ((hashTypeNext(&hi)) != C_ERR) { int ret = DICT_ERR; - sds key, value = NULL; + sds field, value = NULL; - key = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_KEY); + field = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_FIELD); if (withvalues) value = hashTypeCurrentObjectNewSds(&hi, OBJ_HASH_VALUE); - ret = dictAdd(d, key, value); + ret = dictAdd(d, field, value); serverAssert(ret == DICT_OK); } @@ -1044,10 +1098,10 @@ void hrandfieldWithCountCommand(client *c, long l, int withvalues) { dictEntry *de; di = dictGetIterator(d); while ((de = dictNext(di)) != NULL) { - sds key = dictGetKey(de); + sds field = dictGetKey(de); sds value = dictGetVal(de); if (withvalues && c->resp > 2) addWritePreparedReplyArrayLen(wpc, 2); - addWritePreparedReplyBulkSds(wpc, key); + addWritePreparedReplyBulkSds(wpc, field); if (withvalues) addWritePreparedReplyBulkSds(wpc, value); } @@ -1062,25 +1116,25 @@ void hrandfieldWithCountCommand(client *c, long l, int withvalues) { else { /* Hashtable encoding (generic implementation) */ unsigned long added = 0; - listpackEntry key, value; + listpackEntry field, value; dict *d = dictCreate(&hashDictType); dictExpand(d, count); while (added < count) { - hashTypeRandomElement(hash, size, &key, withvalues ? &value : NULL); + hashTypeRandomElement(hash, size, &field, withvalues ? &value : NULL); /* Try to add the object to the dictionary. If it already exists * free it, otherwise increment the number of objects we have * in the result dictionary. */ - sds skey = hashSdsFromListpackEntry(&key); - if (dictAdd(d, skey, NULL) != DICT_OK) { - sdsfree(skey); + sds sfield = hashSdsFromListpackEntry(&field); + if (dictAdd(d, sfield, NULL) != DICT_OK) { + sdsfree(sfield); continue; } added++; /* We can reply right away, so that we don't need to store the value in the dict. 
*/ if (withvalues && c->resp > 2) addWritePreparedReplyArrayLen(wpc, 2); - hashReplyFromListpackEntry(c, &key); + hashReplyFromListpackEntry(c, &field); if (withvalues) hashReplyFromListpackEntry(c, &value); } From 70c9d6e372e503cf5c094d695fb5688b758c2014 Mon Sep 17 00:00:00 2001 From: Rain Valentine Date: Fri, 10 Jan 2025 23:50:26 +0000 Subject: [PATCH 3/3] Add comments, minor fixes Signed-off-by: Rain Valentine --- src/defrag.c | 1 + src/t_hash.c | 22 ++++++++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/defrag.c b/src/defrag.c index 4d62360201..fb98da96c7 100644 --- a/src/defrag.c +++ b/src/defrag.c @@ -480,6 +480,7 @@ static void scanLaterSet(robj *ob, unsigned long *cursor) { *cursor = hashtableScanDefrag(ht, *cursor, activeDefragSdsHashtableCallback, NULL, activeDefragAlloc, HASHTABLE_SCAN_EMIT_REF); } +/* Hashtable scan callback for hash datatype */ static void activeDefragHashTypeEntry(void *privdata, void *element_ref) { UNUSED(privdata); hashTypeEntry **entry_ref = (hashTypeEntry **)element_ref; diff --git a/src/t_hash.c b/src/t_hash.c index 3af41ddddb..b6e6457bb6 100644 --- a/src/t_hash.c +++ b/src/t_hash.c @@ -30,6 +30,10 @@ #include "server.h" #include +/*----------------------------------------------------------------------------- + * Hash Entry API + *----------------------------------------------------------------------------*/ + struct hashTypeEntry { sds value; unsigned char field_offset; @@ -37,7 +41,7 @@ struct hashTypeEntry { }; /* takes ownership of value, does not take ownership of field */ -hashTypeEntry *hashTypeCreateEntry(const sds field, sds value) { +hashTypeEntry *hashTypeCreateEntry(sds field, sds value) { size_t field_size = sdscopytobuffer(NULL, 0, field, NULL); size_t total_size = sizeof(hashTypeEntry) + field_size; @@ -63,12 +67,21 @@ static void hashTypeEntryReplaceValue(hashTypeEntry *entry, sds value) { entry->value = value; } +/* Returns allocation size of hashTypeEntry and data owned by 
hashTypeEntry, + * even if not embedded in the same allocation. */ size_t hashTypeEntryAllocSize(hashTypeEntry *entry) { size_t size = zmalloc_usable_size(entry); size += sdsAllocSize(entry->value); return size; } +/* Defragments a hashtable entry (field-value pair) if needed, using the + * provided defrag functions. The defrag functions return NULL if the allocation + * was not moved, otherwise they return a pointer to the new memory location. + * A separate sds defrag function is needed because of the unique memory layout + * of sds strings. + * If the location of the hashTypeEntry changed we return the new location, + * otherwise we return NULL. */ hashTypeEntry *hashTypeEntryDefrag(hashTypeEntry *entry, void *(*defragfn)(void *), sds (*sdsdefragfn)(sds)) { hashTypeEntry *new_entry = defragfn(entry); if (new_entry) entry = new_entry; @@ -76,9 +89,11 @@ hashTypeEntry *hashTypeEntryDefrag(hashTypeEntry *entry, void *(*defragfn)(void sds new_value = sdsdefragfn(entry->value); if (new_value) entry->value = new_value; - return entry; + return new_entry; } +/* Used for releasing memory to OS to avoid unnecessary CoW. Called when we've + * forked and memory won't be used again. See zmadvise_dontneed() */ void dismissHashTypeEntry(hashTypeEntry *entry) { /* Only dismiss values memory since the field size usually is small. */ dismissSds(entry->value); @@ -603,8 +618,7 @@ static void hashTypeRandomElement(robj *hashobj, unsigned long hashsize, listpac field->sval = (unsigned char *)sds_field; field->slen = sdslen(sds_field); if (val) { - hashTypeEntry *hash_entry = entry; - sds sds_val = hashTypeEntryGetValue(hash_entry); + sds sds_val = hashTypeEntryGetValue(entry); val->sval = (unsigned char *)sds_val; val->slen = sdslen(sds_val); }