📦 encode / databases

📄 test_integration.py · 110 lines
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110import os

import pytest
import sqlalchemy
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.testclient import TestClient

from databases import Database, DatabaseURL

# Backends under test come from a comma-separated environment variable.
# An explicit raise (rather than ``assert``) keeps the check active under ``-O``.
if "TEST_DATABASE_URLS" not in os.environ:
    raise RuntimeError("TEST_DATABASE_URLS is not set.")

# Blank entries (e.g. from a trailing or doubled comma) are dropped so they do
# not turn into bogus "" test parameters.
DATABASE_URLS = [
    url
    for url in (part.strip() for part in os.environ["TEST_DATABASE_URLS"].split(","))
    if url
]

# An empty parameter list would make pytest silently skip the integration
# test, so fail loudly instead.
if not DATABASE_URLS:
    raise RuntimeError("TEST_DATABASE_URLS contains no database URLs.")

metadata = sqlalchemy.MetaData()

# Schema shared by every backend under test: an integer primary key, a text
# column capped at 100 characters, and a boolean completion flag.
notes = sqlalchemy.Table(
    "notes",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("text", sqlalchemy.String(100)),
    sqlalchemy.Column("completed", sqlalchemy.Boolean),
)


def _sync_database_url(url):
    """Return ``url`` rewritten for SQLAlchemy's synchronous ``create_engine``.

    MySQL URLs (plain or ``+aiomysql``) are switched to the ``pymysql``
    driver. The ``aiopg`` / ``aiosqlite`` / ``asyncpg`` variants have their
    driver removed, so the dialect's default driver is used. Any other URL is
    returned unchanged.
    """
    database_url = DatabaseURL(url)
    if database_url.scheme in ("mysql", "mysql+aiomysql"):
        return str(database_url.replace(driver="pymysql"))
    if database_url.scheme in (
        "postgresql+aiopg",
        "sqlite+aiosqlite",
        "postgresql+asyncpg",
    ):
        return str(database_url.replace(driver=None))
    return url


@pytest.fixture(autouse=True, scope="module")
def create_test_database():
    """Create the module's tables on every test database, then drop them.

    Runs automatically once per module. Tables in ``metadata`` are created on
    each URL in ``DATABASE_URLS`` before the tests run and dropped after the
    ``yield``. Each synchronous engine is disposed right after use so its
    pooled connections are closed instead of staying open for the session.
    """
    for url in DATABASE_URLS:
        engine = sqlalchemy.create_engine(_sync_database_url(url))
        try:
            metadata.create_all(engine)
        finally:
            engine.dispose()

    # Run the test suite
    yield

    for url in DATABASE_URLS:
        engine = sqlalchemy.create_engine(_sync_database_url(url))
        try:
            metadata.drop_all(engine)
        finally:
            engine.dispose()


def get_app(database_url):
    """Build a minimal Starlette app around the ``notes`` table.

    The app owns a ``Database`` created with ``force_rollback=True``. It
    connects on startup and disconnects on shutdown.

    Routes:
      * ``GET /notes`` returns every row as ``{"text", "completed"}`` objects.
      * ``POST /notes`` inserts the posted ``text``/``completed`` pair and
        echoes it back.
    """
    database = Database(database_url, force_rollback=True)
    app = Starlette()

    @app.on_event("startup")
    async def startup():
        await database.connect()

    @app.on_event("shutdown")
    async def shutdown():
        await database.disconnect()

    @app.route("/notes", methods=["GET"])
    async def list_notes(request):
        rows = await database.fetch_all(notes.select())
        return JSONResponse(
            [{"text": row["text"], "completed": row["completed"]} for row in rows]
        )

    @app.route("/notes", methods=["POST"])
    async def add_note(request):
        payload = await request.json()
        text, completed = payload["text"], payload["completed"]
        await database.execute(notes.insert().values(text=text, completed=completed))
        return JSONResponse({"text": text, "completed": completed})

    return app


@pytest.mark.parametrize("database_url", DATABASE_URLS)
def test_integration(database_url):
    """Round-trip a note through the app, then check a fresh client sees none.

    The first client session creates a note and reads it back. The second
    session expects an empty list; this relies on the ``force_rollback=True``
    database in ``get_app`` discarding the earlier insert.
    """
    app = get_app(database_url)
    note = {"text": "example", "completed": True}

    with TestClient(app) as client:
        created = client.post("/notes", json=note)
        assert created.status_code == 200
        assert created.json() == note

        listed = client.get("/notes")
        assert listed.status_code == 200
        assert listed.json() == [note]

    with TestClient(app) as client:
        # Ensure sessions are isolated
        listed = client.get("/notes")
        assert listed.status_code == 200
        assert listed.json() == []