# Source code for ssp.spark.streaming.test_spark_stream_producer

import os
from pytest_spark import spark_session

from ssp.spark.streaming.common.twitter_streamer_base import TwitterStreamerBase


def _get_test_spark_stream(spark_session):
    """Build a streaming DataFrame over the local test tweet JSON files.

    Uses the schema declared by ``TwitterStreamerBase`` and points the
    JSON file-stream source at ``data/streams/tweets/`` (resolved to an
    absolute ``file://`` URI so Spark treats it as a local path).
    """
    tweet_schema = TwitterStreamerBase._get_schema()
    source_uri = "file:///" + os.path.abspath("data/streams/tweets/")
    return (
        spark_session.readStream
        .format("json")
        .schema(tweet_schema)
        .load(source_uri)
    )

# Total number of tweet records expected across the test fixture files.
EXPECTED_TWEET_COUNT = 1000


def test_spark_stream(spark_session):
    """End-to-end check of the test tweet stream source.

    Reads the local JSON tweet stream, counts records per micro-batch via a
    Spark accumulator, drains all available data, and asserts the total
    matches the fixture size.

    :param spark_session: pytest-spark provided SparkSession fixture.
    """
    sdf = _get_test_spark_stream(spark_session)
    count_acc = spark_session.sparkContext.accumulator(0)

    def foreach_batch_function(df, epoch_id, count_acc):
        # Transform and write batchDF: fold this batch's row count into the
        # driver-side accumulator (foreachBatch runs on the driver).
        count = df.count()
        count_acc += count

    # processAllAvailable() blocks until every file currently in the source
    # directory has been consumed, so the assertion below sees the full total.
    sdf.writeStream.foreachBatch(
        lambda df, epoch_id: foreach_batch_function(
            df=df, epoch_id=epoch_id, count_acc=count_acc
        )
    ).start().processAllAvailable()

    assert count_acc.value == EXPECTED_TWEET_COUNT