Implementing a PostgreSQL connection pool in Rust

At Craft AI, we are building a new product that lets data scientists push their machine learning algorithms to production quickly and easily. Our goal is to make their lives easier: for example, we handle data storage so that they do not have to worry about saving and loading data from a database.

31/03/2022

For this purpose, we use a PostgreSQL client written in Rust. Since our runtime is Tokio, we chose the deadpool-postgres crate.

Because scalability is central to our product, we wanted to use a connection pool rather than a single connection at a time.

This article describes the first steps for implementing a PostgreSQL connection pool in Rust with the Tokio runtime. There may still be room for improvement but, to the best of our knowledge, this is the only article that shows an example with Json and Uuid columns and with transactions that use a retry mechanism.

That’s why we share this article. 

Furthermore, we welcome any reviews on this article. That would be a great opportunity for us to improve it.

Requirements:
  • Being fairly comfortable with Rust and its core concepts, such as ownership, the borrow checker and lifetimes (>= v1.56.0).
  • Having basic experience with the Tokio runtime (>= v1.12.0).
  • Having basic knowledge of PostgreSQL (>= v9).

Setting up Pool, connection and runtime

Let’s start with the installation of the required crates:

[dependencies]
deadpool-postgres = "0.9.0" 
postgres = "0.19.1" 
tokio = { version = "1.12.0", features = ["rt", "rt-multi-thread", "macros"] }

deadpool-postgres provides the connection pool, postgres provides the client and its types, and we also need the Tokio runtime.

The first step is to establish the connection to the database.

To do so, we first need to build the database configuration:

use deadpool_postgres::{ManagerConfig, RecyclingMethod};
// Helper function to read environment variable with a default value if
// the environment variable is not set.
use super::super::shared::env::env_parse;

fn get_db_config() -> deadpool_postgres::Config {
    let mut config = deadpool_postgres::Config::new();
    config.user = Some(env_parse("DB_USER", "postgres".into()));
    config.password = Some(env_parse("DB_PASSWORD", "password".into()));
    config.dbname = Some(env_parse("DB_NAME", "postgres".into()));
    config.host = Some(env_parse("DB_HOSTNAME", "172.17.0.2".into()));

    config.manager =
       Some(ManagerConfig { recycling_method: RecyclingMethod::Fast });

    config
}
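The `env_parse` helper imported above is not shown in this article. A minimal sketch of what such a helper might look like follows; the signature and the `parse_or` function are assumptions made for illustration, not our production code. It returns the default when the variable is unset or cannot be parsed.

```rust
use std::env;
use std::str::FromStr;

/// Parses `raw` as `T`, falling back to `default` when the value is
/// missing or malformed. (Hypothetical helper, split out so that the
/// parsing logic can be exercised without touching the environment.)
pub fn parse_or<T: FromStr>(raw: Option<String>, default: T) -> T {
    raw.and_then(|value| value.parse::<T>().ok())
        .unwrap_or(default)
}

/// Assumed shape of the `env_parse` helper used in `get_db_config`:
/// reads the environment variable `key` and parses it as `T`.
pub fn env_parse<T: FromStr>(key: &str, default: T) -> T {
    parse_or(env::var(key).ok(), default)
}
```

With this shape, `env_parse("DB_USER", "postgres".into())` infers `T = String` from the field it is assigned to, which matches the calls in `get_db_config`.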

Then we create the connection pool:

use deadpool_postgres::Pool;
use postgres::NoTls;

pub fn create_pool() -> Result<Pool, String> {
    get_db_config().create_pool(NoTls).map_err(|err| err.to_string())
}

Finally, we set up the service with the Tokio runtime and a messaging system (Kafka here, but it could also be RabbitMQ), where each message received is handled in an asynchronous task.

// All the required imports are done here
pub async fn run_service(topics: Vec<String>) -> Result<(), String> {
    // Stuff related to kafka and the tokio runtime
    let pool: Pool = create_pool()?;
    let stream_processor = consumer.stream().try_for_each(|message| {
        let pool = pool.clone();    
        async move {
            let _ = tokio::spawn(async move {
                // Do stuff with the pool
            }).await;
            Ok(())
        }
    });
    stream_processor.await.map_err(|err| format!("Error {:?}", err))?;
    Ok(())
}

Note: as explained in the documentation, cloning the `pool` is cheap, so we do it for each message.
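To illustrate why cloning per message is cheap, here is a sketch that uses only the standard library. It assumes the pool is a thin handle around reference-counted shared state; `FakePool` and `handle_all` are hypothetical stand-ins, and plain threads replace Tokio tasks so that the example needs no external crate.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in for a connection pool: a thin handle around
/// reference-counted shared state. Cloning copies the pointer only.
#[derive(Clone)]
pub struct FakePool {
    inner: Arc<Mutex<Vec<String>>>,
}

impl FakePool {
    pub fn new() -> Self {
        FakePool { inner: Arc::new(Mutex::new(Vec::new())) }
    }

    /// Records that a message was handled; stands in for "use a connection".
    pub fn record(&self, message: &str) {
        self.inner.lock().unwrap().push(message.to_string());
    }

    /// Number of messages handled so far, across all clones.
    pub fn handled(&self) -> usize {
        self.inner.lock().unwrap().len()
    }
}

/// Handles each message in its own thread, giving every worker its own clone.
pub fn handle_all(pool: &FakePool, messages: Vec<String>) {
    let workers: Vec<_> = messages
        .into_iter()
        .map(|message| {
            let pool = pool.clone(); // cheap: only bumps a reference count
            thread::spawn(move || pool.record(&message))
        })
        .collect();
    for worker in workers {
        worker.join().unwrap();
    }
}
```

Every clone sees the same underlying state, which is the property we rely on when we move a clone of the pool into each spawned task.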

CRUD operations

Now that everything is set up, we can dive into the hard part of the code: database queries.

A small but very useful helper function gets a connection from the pool.

use deadpool_postgres::{Client, Pool};

async fn get_connection(pool: &Pool) -> Result<Client, String> {
    pool.get().await.map_err(|err| err.to_string())
}

Error handling is not the subject of this article; here we simply return an error with a String as the error message.

For purposes of illustration, we will use a simple table called people.

Each record has a UUID, a name (text) and a data field (json):

-- pgcrypto provides gen_random_uuid() (built in since PostgreSQL 13)
CREATE EXTENSION IF NOT EXISTS pgcrypto;
CREATE TABLE IF NOT EXISTS people (
    id uuid DEFAULT gen_random_uuid() PRIMARY KEY,
    name text NOT NULL,
    data json
);

At each insertion, the id (UUID) is automatically generated.

So now, we need to update our Cargo.toml file with all the required crates:

[dependencies]
deadpool-postgres = "0.9.0"
postgres = { version = "0.19.1", features = ["with-serde_json-1", "with-uuid-0_8"] }
serde = "1.0.127"
serde_json = "1.0.66"
tokio = { version = "1.12.0", features = ["rt", "rt-multi-thread", "macros"] }
uuid = { version = "0.8.2", features = ["serde", "v4"] }

We need the serde and serde_json crates to serialize and deserialize the Json column, and the uuid crate for the uuid type.

This is the struct for the Json column “data”:

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Serialize, Deserialize)]
pub struct VeryComplicatedData {
    pub many_many_fields: u64,
    pub dict: HashMap<String, String>,    
    // etc...
}

And then the struct for the record:

use uuid::Uuid;

pub struct People {    
    pub id: Uuid,
    pub name: String,
    pub data: Option<VeryComplicatedData>,
}

Finally, the From trait implementation for our struct:

use postgres::Row;
use serde_json::value::Value;

impl From<Row> for People {
    fn from(row: Row) -> Self {
        let val = row.get::<&str, Option<Value>>("data");
        let data = match val {
            Some(Value::Object(x)) =>                 
                Some(serde_json::from_value(Value::Object(x)).unwrap()),
            _ => None,
        };

        Self {
            id: row.get("id"),
            name: row.get("name"),
            data,
        }
    }
}

I haven’t found a better way to deserialize the Json column other than with the `serde_json::value::Value` and then with the `serde_json::from_value` method.

I will let the reader handle the error properly, if any, instead of using the `unwrap()` method.
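One possible way to avoid the `unwrap()` is to implement `TryFrom` instead of `From`, so that a malformed column becomes an `Err` rather than a panic. The sketch below only shows the pattern: `RawRow` and `Person` are hypothetical stand-ins, and the `data` column is plain text parsed as a number so that the example runs without any external crate. With the real `Row`, the parsing step would be the `serde_json::from_value` call.

```rust
use std::convert::TryFrom;

/// Hypothetical stand-in for a database row: the `data` column is kept
/// as raw text here so the example needs no external crate.
pub struct RawRow {
    pub name: String,
    pub data: Option<String>,
}

#[derive(Debug, PartialEq)]
pub struct Person {
    pub name: String,
    pub data: Option<u64>,
}

impl TryFrom<RawRow> for Person {
    type Error = String;

    fn try_from(row: RawRow) -> Result<Self, Self::Error> {
        let RawRow { name, data } = row;
        // A NULL column stays `None`; a malformed value becomes an `Err`
        // instead of a panic.
        let data = match data {
            Some(raw) => Some(
                raw.parse::<u64>()
                    .map_err(|err| format!("Invalid data for {}: {}", name, err))?,
            ),
            None => None,
        };
        Ok(Person { name, data })
    }
}
```

Callers then use `Person::try_from(row)?` and propagate the error like any other failure in the query functions.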

Create a record

Creating a record is straightforward:

use serde_json::json;

async fn insert_people(
    pool: &Pool, 
    name: String, 
    data: VeryComplicatedData,
) -> Result<(), String> {
    get_connection(pool).await?
        .execute(
            "INSERT INTO people (name, data) VALUES ($1, $2)",
            &[&name, &json!(data)],
        )
        .await
        .map(|_| ())
        .map_err(|err| format!("Error while insertion: {}", err))
}

Read records

Now, we are ready to read records in the database:

async fn get_people(pool: &Pool, name: String) -> Result<Vec<People>, String> {
    get_connection(pool).await?
        .query("SELECT * FROM people WHERE name = $1;", &[&name])
        .await
        .map(|rows| 
             rows
               .into_iter()
               .map(|row| row.into())
               .collect::<Vec<People>>()
        )
        .map_err(|err| format!("Error while reading: {}", err))
}

Update records

Updating or deleting a row is pretty straightforward.

async fn update_people(pool: &Pool, id: Uuid, name: String) -> Result<(), String> {
    get_connection(pool).await?
        .execute("UPDATE people SET name = $1 WHERE id = $2;", &[&name, &id])
        .await
        .map(|_| ())
        .map_err(|err| format!("Error while updating: {}", err))
}

Delete records

async fn delete_people(pool: &Pool, id: Uuid) -> Result<(), String> {
    get_connection(pool).await?
        .execute("DELETE FROM people WHERE id = $1;", &[&id])
        .await
        .map(|_| ())
        .map_err(|err| format!("Error while deleting: {}", err))
}

Add multiple rows at once

According to the author of the Rust postgres crate, there are two ways to handle this: we can either prepare the query and populate it with our data, or use the COPY query.

Here we chose to prepare the query. For the COPY method, we refer the reader to the crate author's explanation.

First, we need to create the query string. Since we use parameterized queries, we need to index each element passed to the query.

Caution: indexes start at 1, not 0.

fn get_values_param(people: &Vec<(String, VeryComplicatedData)>) -> String {
    let values_param = people.iter().enumerate()
        .map(|(i, (_, _))| format!("(${}, ${})", i * 2 + 1, i * 2 + 2))
        .collect::<Vec<String>>()
        .join(", ");

    format!("INSERT INTO people (name, data) VALUES {};", values_param) 
}
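To make the numbering concrete, here is a small standalone sketch of the same indexing, generalized to any number of columns. `values_placeholders` is a hypothetical name used only for this illustration.

```rust
/// Builds the `VALUES` placeholder groups for `rows` records of `columns`
/// fields each, numbering parameters from $1 as PostgreSQL expects.
pub fn values_placeholders(rows: usize, columns: usize) -> String {
    (0..rows)
        .map(|row| {
            let group = (1..=columns)
                .map(|column| format!("${}", row * columns + column))
                .collect::<Vec<String>>()
                .join(", ");
            format!("({})", group)
        })
        .collect::<Vec<String>>()
        .join(", ")
}
```

For three people with two columns each, `values_placeholders(3, 2)` yields `($1, $2), ($3, $4), ($5, $6)`, which is the fragment that `get_values_param` splices into the INSERT statement.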

In the end, we have:

use postgres::types::ToSql;
use serde_json::{json, Value};

async fn insert_people(
    pool: &Pool,
    people: Vec<(String, VeryComplicatedData)>,
) -> Result<(), String> {
    let query_str = get_values_param(&people);
    let fmt_people = people.into_iter()
        .map(|(name, data)| (name, json!(data)))
        .collect::<Vec<(String, Value)>>();
    let mut values: Vec<&(dyn ToSql + Sync)> = Vec::new();
    for (name, jdata) in fmt_people.iter() {
        values.push(name);
        values.push(jdata);
    }

    get_connection(pool).await?
        .execute(&query_str, &values[..])
        .await
        .map(|_| ())
        .map_err(|err| format!("Error while insertions: {}", err))
}
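One caveat with this approach: the PostgreSQL wire protocol encodes the number of bind parameters as a 16-bit integer, so a single statement cannot carry more than 65,535 of them. For very large inserts, a possible workaround is to split the records into batches and run one INSERT per batch. The sketch below shows only the splitting step; `batches` and `MAX_BIND_PARAMS` are hypothetical names.

```rust
/// Upper bound on bind parameters in one statement: the PostgreSQL wire
/// protocol encodes the parameter count as a 16-bit integer.
pub const MAX_BIND_PARAMS: usize = 65_535;

/// Splits `records` into batches whose total parameter count stays within
/// `MAX_BIND_PARAMS`, given `params_per_record` placeholders per record.
pub fn batches<T>(records: &[T], params_per_record: usize) -> Vec<&[T]> {
    assert!(params_per_record > 0 && params_per_record <= MAX_BIND_PARAMS);
    let batch_len = MAX_BIND_PARAMS / params_per_record;
    records.chunks(batch_len).collect()
}
```

With two parameters per record, as in our people table, each batch holds at most 32,767 records.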

Transaction with retry

When we use transactions, two issues can arise:

  • Lock: when a request needs to access a table that is already locked by another request
  • Deadlock: when several requests need to access several tables and try to lock them at the same time

Fortunately, a clean error is returned in each of these cases.

Hence this very useful helper:

use postgres::error::SqlState;

fn should_retry(sql_error: &postgres::Error) -> bool {
    sql_error
        .code()
        .map(|e| {
            *e == SqlState::T_R_SERIALIZATION_FAILURE
                || *e == SqlState::T_R_DEADLOCK_DETECTED
        })
        .unwrap_or(false)
}
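For reference, the two constants above correspond to the SQLSTATE codes `40001` (serialization_failure) and `40P01` (deadlock_detected), both in class 40, "transaction rollback". The same check can be sketched on raw code strings, which can be handy when the code comes from logs; `is_retryable_sqlstate` is a hypothetical name.

```rust
/// SQLSTATE for `serialization_failure` (class 40, transaction rollback).
pub const SERIALIZATION_FAILURE: &str = "40001";
/// SQLSTATE for `deadlock_detected` (class 40, transaction rollback).
pub const DEADLOCK_DETECTED: &str = "40P01";

/// Returns true when the five-character SQLSTATE `code` denotes a
/// transient failure that is worth retrying.
pub fn is_retryable_sqlstate(code: &str) -> bool {
    code == SERIALIZATION_FAILURE || code == DEADLOCK_DETECTED
}
```

Any other code, such as `23505` (unique_violation), is treated as a permanent failure and is not retried.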

All the SqlState values are listed in the postgres crate documentation.

We tried to write a high-level function to query the database with the retry mechanism, but it is not as simple as it seems. Instead, we use this helper function in a very procedural way.

// We could retry indefinitely, but the loop might then never end; it is
// better to fail at some point. In our production code, we use a threshold
// of 10.
const QUERY_RETRY_COUNT: usize = 10;

async fn get_people(pool: &Pool, name: String) -> Result<Vec<People>, String> {
    let mut conn = get_connection(pool).await?;


    // We add a for loop to manage the maximum number of attempts.
    // After this number of attempts, we return an Err.
    for _ in 0..QUERY_RETRY_COUNT {
        // First part is the transaction query.
        let transaction = conn.transaction().await.map_err(|err| err.to_string())?;
        let result = transaction.query(
            "SELECT * FROM people WHERE name = $1;",
            &[&name],
        ).await;
        let result_commit = transaction.commit().await;


        // Second part: we check the transaction commit result. If it failed with
        // a retryable error, we jump to the next loop iteration for a retry;
        // otherwise we handle the query result.
        if matches!(&result_commit, Err(err) if should_retry(err)) {
            continue;
        }


        // Third and last part: we handle the result. A “retry error” can occur
        // during the transaction commit, but also during the query.
        match result {
            // The query succeeded, so we return the result.
            Ok(rows) => return Ok(rows.into_iter()
                .map(|row| row.into())
                .collect::<Vec<People>>()),
            // The query failed, but it can be a “retry failure”, if it is the case,
            // we jump to the next for loop iteration for a new retry.
            Err(err) => {
                if should_retry(&err) {
                    continue;
                }


                // Otherwise we return the error.
                return Err(format!(
                    "Unable to get people with name \"{}\": {}",
                    name, err
                ));
            }
        };
    }

    Err("Too many retries.".into())
}
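The control flow of this loop can be isolated in a small synchronous sketch. It is only an illustration of the bounded-retry logic, not our production code: `with_retry` is a hypothetical name, and the synchronous closures sidestep the difficulties that made an async high-level function hard to write.

```rust
/// Runs `operation` until it succeeds, fails with a non-retryable error,
/// or `max_attempts` is reached.
pub fn with_retry<T, E, Op, Retry>(
    max_attempts: usize,
    mut operation: Op,
    should_retry: Retry,
) -> Result<T, String>
where
    E: std::fmt::Display,
    Op: FnMut() -> Result<T, E>,
    Retry: Fn(&E) -> bool,
{
    for _ in 0..max_attempts {
        match operation() {
            Ok(value) => return Ok(value),
            // Transient failure: try again on the next iteration.
            Err(ref err) if should_retry(err) => continue,
            // Permanent failure: give up immediately.
            Err(err) => return Err(err.to_string()),
        }
    }
    Err("Too many retries.".into())
}
```

The three outcomes mirror the function above: success returns at once, a retryable error moves on to the next attempt, and any other error is reported without further attempts.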

If you need another isolation level, you can use the `build_transaction` method:

let transaction = conn
    .build_transaction()
    .isolation_level(IsolationLevel::Serializable)
    .start()
    .await
    .map_err(|err| err.to_string())?;

For more details, you can read the crate documentation.

To dive deeper into this topic, there is the Rust forum where you can ask any question: https://users.rust-lang.org/

I hope this will help you. 

We welcome any suggestions to improve this article. Feel free to share your experience as well.

References:

Photo by Jorgen Hendriksen on Unsplash
