Effective parallel execution of inserts with the Python driver (2.x)

Hi TypeDB community,

It would be great if you could help me shed some light on my failed attempts at reducing the execution time of a bulk load when using the Python driver in TypeDB 2.29.0.

The initial approach consisted of a single process with batches, which takes around 8 minutes to load ~42000 insert queries. The relevant functions look roughly as follows:

from typedb.driver import TypeDB, SessionType, TransactionType

def load(self, data: list[str], batch_size=100):
    amount = len(data)

    with TypeDB.core_driver(self.server_address) as driver:
        with driver.session(self.db_name, SessionType.DATA) as session:
            for i in range(0, amount, batch_size):
                self.load_batch(session, data[i:i+batch_size])
    
def load_batch(self, session, queries: list[str]):
    with session.transaction(TransactionType.WRITE) as transaction:
        for query in queries:
            transaction.query.insert(query)
        transaction.commit()

To speed up the loading, I tried to implement some form of multi-threading.

As per the Python API documentation, I used TypeDBOptions(parallel=True) to enable the use of parallel execution in the server; this in combination with a ThreadPoolExecutor from the concurrent.futures Python module. The modified load function looks as follows:

import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def load(self, data: list[str], batch_size=100):
    amount = len(data)
    batches = []
    for i in range(0, amount, batch_size):
        batches.append(data[i:i+batch_size])
    with TypeDB.core_driver(self.server_address) as driver:
        with driver.session(self.db_name, SessionType.DATA, TypeDBOptions(parallel=True)) as session:
            with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
                executor.map(partial(self.load_batch, session), batches)

However, the runtime remains unchanged. Some additional notes here:

  • It seemed more natural to run several threads committing transactions concurrently in the same session (as in the code above). However, I also tried moving the session context manager into the load_batch function, and then moving the driver as well so that multiple drivers ran in parallel (roughly as sketched below); neither made any difference in execution time.
  • I also played with the batch size, without any improvement.
  • I also followed the example in the documentation (TypeDB | Docs > Manual > Optimizing speed), although it does not cover the parallel option. Again, no noticeable difference.
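
For reference, the fully shifted variant (driver, session and transaction all opened inside each thread) looked roughly like this - simplified, and load_batch_standalone is just an illustrative name:

def load_parallel_drivers(self, data: list[str], batch_size=100):
    batches = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
        executor.map(self.load_batch_standalone, batches)

def load_batch_standalone(self, queries: list[str]):
    # each thread opens its own driver and session instead of sharing one
    with TypeDB.core_driver(self.server_address) as driver:
        with driver.session(self.db_name, SessionType.DATA) as session:
            with session.transaction(TransactionType.WRITE) as transaction:
                for query in queries:
                    transaction.query.insert(query)
                transaction.commit()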

At this point, I’d appreciate your help clarifying the following questions:

  1. As I understand it, TypeDBOptions(parallel=True) acts on the server, so if it is enabled for the session it should remain enabled for its transactions too. Is this correct? (In any case, I tried passing the option to both the session and the transaction, with no noticeable difference.)
  2. Do any settings need to be modified in the server for the parallelization to take effect?
  3. Could it be that due to the complexity of the inserts (perhaps too many matching statements) this is the best performance that can be achieved?
  4. Is there perhaps a flaw in the implementation above that I’m missing?

Any insight or suggestions you can provide will be greatly appreciated.

Thanks in advance

Hello!

Parallelisation is indeed a good way to optimise your ingestion times.

  1. parallel=True in TypeDBOptions is the default, and I would keep it - it allows the server to execute the lookups required to answer your query in parallel (see the snippet just after this list)
  2. you don’t need to modify the server
  3. it’s unlikely your inserts are too complex to benefit from parallelisation
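
For completeness, here is a minimal 2.x sketch of passing the option explicitly - the database name and query are just placeholders, and the behaviour should match the default:

from typedb.driver import TypeDB, SessionType, TransactionType, TypeDBOptions

options = TypeDBOptions(parallel=True)  # parallel=True is already the default
with TypeDB.core_driver("localhost:1729") as driver:
    with driver.session("my_db", SessionType.DATA, options) as session:
        with session.transaction(TransactionType.WRITE, options) as tx:
            tx.query.insert('insert $p isa person, has name "Alice";')
            tx.commit()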

Which leaves me thinking we can improve your code to do the parallelisation!

Here’s the architecture we normally employ:

  1. Use multiprocessing instead of multithreading, to avoid any Python GIL issues
  2. Each ‘worker’ receives a set of data to ingest and opens its own driver + session + transactions to load its share in batches (a rough 2.x sketch follows this list, and a fuller 3.0 example after that)
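
A rough (untested) 2.x sketch of that architecture, reusing the core-driver API from your own snippets - run_worker and load_parallel are just illustrative names:

import multiprocessing
from typedb.driver import TypeDB, SessionType, TransactionType

def run_worker(process_id: int, address: str, db_name: str, queries: list[str], batch_size: int):
    # each worker process opens its own driver and session
    with TypeDB.core_driver(address) as driver:
        with driver.session(db_name, SessionType.DATA) as session:
            for i in range(0, len(queries), batch_size):
                with session.transaction(TransactionType.WRITE) as tx:
                    for query in queries[i:i + batch_size]:
                        tx.query.insert(query)
                    tx.commit()

def load_parallel(address: str, db_name: str, data: list[str], num_workers: int = 4, batch_size: int = 100):
    # split the queries into one chunk per worker and run the workers as separate processes
    chunk_size = (len(data) + num_workers - 1) // num_workers
    chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]
    processes = [
        multiprocessing.Process(target=run_worker, args=(i, address, db_name, chunk, batch_size))
        for i, chunk in enumerate(chunks)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()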

Here’s an example for TypeDB 3.0 I had lying around, which you should be able to translate into 2.x pretty easily!


import argparse
import multiprocessing
import time
from typedb.driver import *

def insert_batch(transaction, batch):
    for query in batch:
        # for 2.x: transaction.query.insert(query)
        transaction.query(query)

def worker(process_id: int, address: str, user: str, pw: str, db: str, data, batch_size: int, result_queue: multiprocessing.Queue):
    driver = TypeDB.driver(address, credentials=Credentials(user, pw), driver_options=DriverOptions(False, None))
    # for 2.x, also create a session for this worker
   
    process_time_ms = 0
    batch_num = 0

    while len(data) > 0:
        batch_num += 1
        batch = data[0:batch_size]
        data = data[batch_size:]

        print(f"Process {process_id}: Inserting batch {batch_num}")
        start_time = time.time()
        with driver.transaction(db, TransactionType.WRITE) as transaction:
            insert_batch(transaction, batch)
            transaction.commit()
        end_time = time.time()
        batch_time_ms = (end_time - start_time) * 1000
        process_time_ms += batch_time_ms
        print(f"Process {process_id}: Completed batch {batch_num}")

    result_queue.put((process_id, process_time_ms))

def main():
    parser = argparse.ArgumentParser(description='Load test data into TypeDB')
    parser.add_argument('--batch-size', type=int, default=100, help='Size of each batch')
    parser.add_argument('--address', type=str, default='localhost:1729', help='TypeDB server address')
    parser.add_argument('--database', type=str, default='load_test_db', help='Database name')
    parser.add_argument('--num-threads', type=int, default=1, help='Number of worker processes to use')
    args = parser.parse_args()

    # ... create a driver to create the database, load the schema,
    #     set `user` / `password`, and prepare `data` (the list of insert queries) ...

    # Split the data into one chunk per worker process
    chunk_size = (len(data) + args.num_threads - 1) // args.num_threads
    worker_chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]


    # Create a queue for collecting (timing) results from processes
    result_queue = multiprocessing.Queue()

    # Create and start processes
    processes = []
    for i, chunk in enumerate(worker_chunks):
        process = multiprocessing.Process(
            target=worker,
            args=(i, args.address, user, password, args.database, chunk, args.batch_size, result_queue)
        )
        )
        processes.append(process)
        process.start()

    # Wait for all processes to complete
    for process in processes:
        process.join()

    # Collect results from the queue
    results = [0] * args.num_threads
    while not result_queue.empty():
        process_id, process_time_ms = result_queue.get()
        results[process_id] = process_time_ms

    # Calculate and print statistics
    total_time_ms = sum(results)
    num_batches = -(-len(data) // args.batch_size)  # ceiling division
    print(f"\nTotal time: {total_time_ms:.2f} ms")
    print(f"Average batch time: {total_time_ms/num_batches:.2f} ms")
    print(f"Average time per insert: {total_time_ms/len(data):.2f} ms")

Hi Joshua,

Thanks for taking the time to answer. I tried your suggestion but unfortunately the ingestion time did not show any improvement. However, based on your feedback I’ll assume for now that the problem is not in the logic of the implementation, so I will move to finding bottlenecks in the process.

Thanks again and have a nice day.

That is unexpected! You can check whether it is parallelising by watching whether the server’s using more than one core - with enough client-side workers you should be able to fully saturate it. If you are still running threads rather than processes, it’s likely the GIL is kicking in and preventing efficient parallelism. We’ve seen almost linear speedups in TypeDB 2.x with this approach to parallelisation.
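
If the server runs on the same machine, a quick sketch like this can sample per-core usage while the load is running (it assumes the third-party psutil package; top/htop work just as well):

import psutil  # third-party: pip install psutil

def sample_cpu(seconds: float = 30, interval: float = 1.0):
    # print per-core CPU utilisation once per interval while the load runs
    for _ in range(int(seconds / interval)):
        per_core = psutil.cpu_percent(interval=interval, percpu=True)
        busy = sum(1 for pct in per_core if pct > 50)
        print(f"{busy}/{len(per_core)} cores above 50%: {per_core}")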

I’ll drop in my working tiny TypeDB 3.0 write benchmark (it generates data on the fly), which I believe you should be able to run against TypeDB 2.x by tweaking the driver invocations :slight_smile:

import random
import string
from datetime import datetime, timedelta
import argparse
import time
import multiprocessing
from typing import List
from typedb.driver import *

def generate_random_name(min_length: int = 3, max_length: int = 12) -> str:
    return ''.join(random.choices(string.ascii_letters, k=random.randint(min_length, max_length)))

def generate_random_email() -> str:
    name = generate_random_name(5, 10)
    domain = generate_random_name(3, 6)
    return f"{name}@{domain}.com"

def generate_random_date() -> datetime:
    current_datetime = datetime.now()
    random_days = random.randint(0, 10000)
    return current_datetime + timedelta(days=random_days)

def generate_random_address() -> str:
    street_number = random.randint(1, 9999)
    street_name = generate_random_name(5, 15)
    city = generate_random_name(5, 10)
    state = ''.join(random.choices(string.ascii_uppercase, k=2))
    zip_code = ''.join(random.choices(string.digits, k=5))
    return f"{street_number} {street_name} St, {city}, {state} {zip_code}"

def define_schema(transaction: Transaction):
    # Define person entity type
    transaction.query("""
        define
        entity person,
            owns name,
            owns age,
            owns email @key,
            owns birth_date @card(1),
            owns home_address;
        attribute name, value string;
        attribute age, value integer;
        attribute email, value string;
        attribute birth_date, value datetime;
        attribute home_address, value string;
    """)

def insert_person_batch(transaction: Transaction, batch_size: int):
    for _ in range(batch_size):
        name = generate_random_name()
        age = random.randint(0, 100)
        email = generate_random_email()
        birth_date = generate_random_date()
        home_address = generate_random_address()
        
        transaction.query(f"""
            insert
            $p isa person, has name "{name}", has age {age}, 
                has email "{email}", has birth_date {birth_date.isoformat()}, 
                has home_address "{home_address}";
        """)

def worker(process_id: int, address: str, user: str, pw: str, db: str, batches_per_process: int, batch_size: int, total_batches: int, result_queue: multiprocessing.Queue):
    driver = TypeDB.driver(address, credentials=Credentials(user, pw), driver_options=DriverOptions(False, None))
    start_batch = process_id * batches_per_process
    end_batch = min(start_batch + batches_per_process, total_batches)
    process_time_ms = 0

    for batch_num in range(start_batch, end_batch):
        print(f"Process {process_id}: Inserting batch {batch_num + 1}/{total_batches}")
        start_time = time.time()
        with driver.transaction(db, TransactionType.WRITE) as transaction:
            insert_person_batch(transaction, batch_size)
            transaction.commit()
        end_time = time.time()
        batch_time_ms = (end_time - start_time) * 1000
        process_time_ms += batch_time_ms
        print(f"Process {process_id}: Completed batch {batch_num + 1}")

    result_queue.put((process_id, process_time_ms))

def main():
    parser = argparse.ArgumentParser(description='Load test data into TypeDB')
    parser.add_argument('--num-batches', type=int, default=1000, help='Number of batches to insert')
    parser.add_argument('--batch-size', type=int, default=500, help='Size of each batch')
    parser.add_argument('--address', type=str, default='localhost:1729', help='TypeDB server address')
    parser.add_argument('--database', type=str, default='load_test_db', help='Database name')
    parser.add_argument('--num-threads', type=int, default=1, help='Number of worker processes to use')
    args = parser.parse_args()

    user = "admin"
    pw = "password"

    # Connect to TypeDB with a temporary driver to set up the database and schema
    driver = TypeDB.driver(args.address, credentials=Credentials(user, pw), driver_options=DriverOptions(False, None))
    
    db = args.database
    # Create database if it doesn't exist
    if driver.databases.contains(db):
        driver.databases.get(db).delete()
    driver.databases.create(db)
    
    # Define schema
    with driver.transaction(db, TransactionType.SCHEMA) as transaction:
        define_schema(transaction)
        transaction.commit()

    # Calculate batches per process (ceiling division so no batch is skipped)
    batches_per_process = -(-args.num_batches // args.num_threads)
    if batches_per_process == 0:
        batches_per_process = 1
        args.num_threads = args.num_batches

    # Create a queue for collecting results from processes
    result_queue = multiprocessing.Queue()

    # Create and start processes
    processes = []
    for i in range(args.num_threads):
        process = multiprocessing.Process(
            target=worker, 
            args=(i, args.address, user, pw, db, batches_per_process, args.batch_size, args.num_batches, result_queue)
        )
        processes.append(process)
        process.start()

    # Wait for all processes to complete
    for process in processes:
        process.join()

    # Collect results from the queue
    results = [0] * args.num_threads
    while not result_queue.empty():
        process_id, process_time_ms = result_queue.get()
        results[process_id] = process_time_ms

    # Calculate and print statistics
    total_time_ms = sum(results)
    print(f"\nTotal time: {total_time_ms:.2f}ms")
    print(f"Average batch time: {total_time_ms/args.num_batches:.2f} ms")
    print(f"Average time per insert: {total_time_ms/(args.num_batches*args.batch_size):.2f} ms")

if __name__ == "__main__":
    main()
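
To compare single-process and multi-process runs, keep --num-batches and --batch-size fixed and vary --num-threads (e.g. 1 vs 8), then compare the reported total and per-insert times.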