Source code for pragma_sdk.onchain.mixins.oracle

import time
from typing import Callable, Coroutine, Dict, List, Optional, Sequence

from deprecated import deprecated
from starknet_py.contract import InvokeResult
from starknet_py.net.account.account import Account
from starknet_py.net.client import Client

from pragma_sdk.common.logging import get_pragma_sdk_logger
from pragma_sdk.common.utils import felt_to_str, str_to_felt
from pragma_sdk.common.types.entry import Entry, FutureEntry, SpotEntry, GenericEntry
from pragma_sdk.common.types.types import AggregationMode
from pragma_sdk.common.types.asset import Asset
from pragma_sdk.common.types.pair import Pair
from pragma_sdk.common.types.types import (
    DataTypes,
    Address,
    Decimals,
    UnixTimestamp,
)

from pragma_sdk.onchain.types.execution_config import ExecutionConfig
from pragma_sdk.onchain.types import (
    OracleResponse,
    Checkpoint,
    Contract,
    BlockId,
)

logger = get_pragma_sdk_logger()


[docs] class OracleMixin: publisher_registry: Contract client: Client account: Account execution_config: ExecutionConfig oracle: Contract is_user_client: bool = False track_nonce: Callable[[object, int, int], Coroutine[None, None, None]]
[docs] @deprecated async def publish_spot_entry( self, pair_id: int, value: int, timestamp: UnixTimestamp, source: int, publisher: int, volume: int = 0, ) -> InvokeResult: if not self.is_user_client: raise AttributeError( "Must set account. " "You may do this by invoking " "self._setup_account_client(private_key, account_contract_address)" ) invocation = await self.oracle.functions["publish_data"].invoke( new_entry={ "Spot": { "base": { "timestamp": timestamp, "source": source, "publisher": publisher, }, "price": value, "pair_id": pair_id, "volume": volume, } }, execution_config=self.execution_config, ) return invocation
[docs] async def publish_many(self, entries: List[Entry]) -> List[InvokeResult]: if not entries: logger.warning("publish_many received no entries to publish. Skipping") return [] spot_entries: List[Entry] = [ entry for entry in entries if isinstance(entry, SpotEntry) ] future_entries: List[Entry] = [ entry for entry in entries if isinstance(entry, FutureEntry) ] generic_entries: List[Entry] = [ entry for entry in entries if isinstance(entry, GenericEntry) ] invocations = [] invocations.extend(await self._publish_entries(spot_entries, DataTypes.SPOT)) invocations.extend( await self._publish_entries(future_entries, DataTypes.FUTURE) ) invocations.extend( await self._publish_entries(generic_entries, DataTypes.GENERIC) ) return invocations
async def _publish_entries( self, entries: List[Entry], data_type: DataTypes ) -> List[InvokeResult]: if len(entries) == 0: return [] invocations = [] match data_type: case DataTypes.SPOT: serialized_entries = SpotEntry.serialize_entries(entries) case DataTypes.FUTURE: serialized_entries = FutureEntry.serialize_entries(entries) case DataTypes.GENERIC: serialized_entries = GenericEntry.serialize_entries(entries) pagination = self.execution_config.pagination if pagination: for i in range(0, len(serialized_entries), pagination): entries_subset = serialized_entries[i : i + pagination] invocation = await self._invoke_publish(entries_subset, data_type) invocations.append(invocation) self._log_transaction(invocation, len(entries_subset), data_type) else: invocation = await self._invoke_publish(serialized_entries, data_type) invocations.append(invocation) self._log_transaction(invocation, len(serialized_entries), data_type) return invocations async def _invoke_publish( self, entries: List[Dict], data_type: DataTypes ) -> InvokeResult: return await self.oracle.functions["publish_data_entries"].invoke( new_entries=[{data_type: entry} for entry in entries], execution_config=self.execution_config, callback=self.track_nonce, ) def _log_transaction( self, invocation: InvokeResult, entry_count: int, data_type: DataTypes ): logger.debug( f"Sent {entry_count} updated {data_type.name.lower()} entries with transaction {hex(invocation.hash)}" )
[docs] @deprecated async def get_spot_entries( self, pair_id, sources=None, block_id: Optional[BlockId] = "latest", ) -> List[SpotEntry]: if sources is None: sources = [] if isinstance(pair_id, str): pair_id = str_to_felt(pair_id.upper()) elif not isinstance(pair_id, int): raise TypeError( "Pair ID must be string (will be converted to felt) or integer" ) (response,) = await self.oracle.functions["get_data_entries_for_sources"].call( Asset(DataTypes.SPOT, pair_id, None).serialize(), sources, block_number=block_id, ) entries = response[0] return [SpotEntry.from_dict(dict(entry.value)) for entry in entries]
[docs] async def get_all_sources( self, asset: Asset, block_id: Optional[BlockId] = "latest", ) -> List[str]: """ Query on-chain all sources used for a given asset. :param asset: Asset :param block_id: Block number or Block Tag :return: List of sources """ (response,) = await self.oracle.functions["get_all_sources"].call( asset.serialize(), block_number=block_id ) return [felt_to_str(source) for source in response]
[docs] @deprecated async def get_future_entries( self, pair_id: str | int, expiration_timestamp: UnixTimestamp, sources: Optional[List[str | int]] = None, block_id: Optional[BlockId] = "latest", ) -> List[FutureEntry]: if sources is None: sources = [] if isinstance(pair_id, str): pair_id = str_to_felt(pair_id.upper()) elif not isinstance(pair_id, int): raise TypeError( "Pair ID must be string (will be converted to felt) or integer" ) (response,) = await self.oracle.functions["get_data_entries_for_sources"].call( Asset(DataTypes.FUTURE, pair_id, expiration_timestamp).serialize(), sources, block_number=block_id, ) entries = response[0] return [FutureEntry.from_dict(dict(entry.value)) for entry in entries]
[docs] async def get_spot( self, pair_id: str | int, aggregation_mode: AggregationMode = AggregationMode.MEDIAN, sources: Optional[List[str | int]] = None, block_id: Optional[BlockId] = "latest", ) -> OracleResponse: """ Query the Oracle contract for the data of a spot asset. :param pair_id: Pair ID :param aggregation_mode: AggregationMode :param sources: List of sources, if None will use all sources :param block_id: Block number or Block Tag :return: OracleResponse """ if isinstance(pair_id, str): pair_id = str_to_felt(pair_id.upper()) elif not isinstance(pair_id, int): raise TypeError( "Pair ID must be string (will be converted to felt) or integer" ) if sources is None: (response,) = await self.oracle.functions["get_data"].call( Asset(DataTypes.SPOT, pair_id, None).serialize(), aggregation_mode.serialize(), block_number=block_id, ) else: (response,) = await self.oracle.functions["get_data_for_sources"].call( Asset(DataTypes.SPOT, pair_id, None).serialize(), aggregation_mode.serialize(), sources, block_number=block_id, ) response = dict(response) return OracleResponse( response["price"], response["decimals"], response["last_updated_timestamp"], response["num_sources_aggregated"], response["expiration_timestamp"], )
[docs] async def get_entry( self, pair_id: str | int, data_type: DataTypes, publisher: str | int, source: str | int, expiration_timestamp: Optional[int] = None, block_id: Optional[BlockId] = "latest", ) -> Entry: """ Query the Oracle contract for the entry of a publisher for a source. :param pair_id: Pair ID :param data_type: DataTypes :param publisher: Publisher to check entry for :param source: Source to check :param expiration_timestamp: Optional, expiration timestamp for futures. Defaults to 0. :param block_id: Block number or Block Tag :return: Entry """ if data_type == DataTypes.FUTURE and expiration_timestamp is None: expiration_timestamp = 0 if isinstance(pair_id, str): pair_id = str_to_felt(pair_id.upper()) elif not isinstance(pair_id, int): raise TypeError( "Pair ID must be string (will be converted to felt) or integer" ) match data_type: case DataTypes.SPOT | DataTypes.GENERIC: asset = Asset(data_type, pair_id, None) case DataTypes.FUTURE: asset = Asset(data_type, pair_id, expiration_timestamp) (response,) = await self.oracle.functions["get_data_entry"].call( asset.serialize(), source, publisher, block_number=block_id, ) response = response.as_dict() response = dict(response["value"]) entry: Entry match data_type: case DataTypes.SPOT: entry = SpotEntry.from_dict(response) case DataTypes.FUTURE: entry = FutureEntry.from_dict(response) case DataTypes.GENERIC: entry = GenericEntry.from_dict(response) return entry
[docs] async def get_future( self, pair_id: str | int, expiry_timestamp: UnixTimestamp, aggregation_mode: AggregationMode = AggregationMode.MEDIAN, sources: Optional[List[str | int]] = None, block_id: Optional[BlockId] = "latest", ) -> OracleResponse: """ Query the Oracle contract for the data of a future asset. :param pair_id: Pair ID :param expiry_timestamp: Expiry timestamp of the future contract :param aggregation_mode: AggregationMode :param sources: List of sources, if None will use all sources :param block_id: Block number or Block Tag :return: OracleResponse """ if isinstance(pair_id, str): pair_id = str_to_felt(pair_id.upper()) elif not isinstance(pair_id, int): raise TypeError( "Pair ID must be string (will be converted to felt) or integer" ) if sources is None: (response,) = await self.oracle.functions["get_data"].call( Asset(DataTypes.FUTURE, pair_id, expiry_timestamp).serialize(), aggregation_mode.serialize(), block_number=block_id, ) else: (response,) = await self.oracle.functions["get_data_for_sources"].call( Asset(DataTypes.FUTURE, pair_id, expiry_timestamp).serialize(), aggregation_mode.serialize(), sources, block_number=block_id, ) response = dict(response) return OracleResponse( response["price"], response["decimals"], response["last_updated_timestamp"], response["num_sources_aggregated"], response["expiration_timestamp"], )
[docs] async def get_generic( self, key: str | int, sources: Optional[List[str | int]] = None, block_id: Optional[BlockId] = "latest", ) -> GenericEntry: """ Query the Oracle contract to retrieve the :param key: Key ID of the generic entry :param block_id: Block number or Block Tag :return: GenericEntry """ if isinstance(key, str): key = str_to_felt(key.upper()) elif not isinstance(key, int): raise TypeError( "Generic entry key must be string (will be converted to felt) or integer" ) if sources is None: (response,) = await self.oracle.functions["get_data_entries"].call( Asset(DataTypes.GENERIC, key, None).serialize(), block_number=block_id, ) else: (response,) = await self.oracle.functions[ "get_data_entries_for_sources" ].call( Asset(DataTypes.GENERIC, key, None).serialize(), sources, block_number=block_id, ) # NOTE: We only return the latest entry because there shouldn't more # than one entry with the same key. response = response[-1].as_dict() entry = dict(response["value"]) return GenericEntry( key=entry["key"], value=entry["value"], timestamp=entry["base"]["timestamp"], source=entry["base"]["source"], publisher=entry["base"]["publisher"], )
[docs] async def get_decimals( self, asset: Asset, block_id: Optional[BlockId] = "latest", ) -> Decimals: """ Query on-chain the decimals for a given asset :param asset: Asset :param block_id: Block number or Block Tag :return: Decimals """ (response,) = await self.oracle.functions["get_decimals"].call( asset.serialize(), block_number=block_id, ) return response # type: ignore[no-any-return]
[docs] async def set_future_checkpoints( self, pair_ids: Sequence[int], expiry_timestamps: Sequence[int], aggregation_mode: AggregationMode = AggregationMode.MEDIAN, ) -> InvokeResult: assert len(pair_ids) == len(expiry_timestamps) if not self.is_user_client: raise AttributeError( "Must set account. " "You may do this by invoking " "self._setup_account_client(private_key, account_contract_address)" ) invocation = None pagination = self.execution_config.pagination if pagination: index = 0 while index < len(pair_ids): pair_ids_subset = pair_ids[index : index + pagination] expiries_subset = expiry_timestamps[index : index + pagination] invocation = await self.oracle.set_checkpoints.invoke( [ Asset(DataTypes.FUTURE, pair_id, expiry).serialize() for pair_id, expiry in zip(pair_ids_subset, expiries_subset) ], aggregation_mode.serialize(), max_fee=self.execution_config.max_fee, callback=self.track_nonce, ) index += pagination logger.info( "Set future checkpoints for %d pair IDs with transaction %s", len(pair_ids_subset), hex(invocation.hash), ) else: invocation = await self.oracle.set_checkpoints.invoke( [ Asset(DataTypes.FUTURE, pair_id, expiry).serialize() for pair_id, expiry in zip(pair_ids, expiry_timestamps) ], aggregation_mode.serialize(), max_fee=self.execution_config.max_fee, callback=self.track_nonce, ) return invocation
[docs] async def set_checkpoints( self, pair_ids: Sequence[str | int], aggregation_mode: AggregationMode = AggregationMode.MEDIAN, ) -> InvokeResult: """ Set checkpoints for a list of pair IDs. :param pair_ids: List of pair IDs :param aggregation_mode: AggregationMode :return: InvokeResult """ if not self.is_user_client: raise AttributeError( "Must set account. " "You may do this by invoking " "self._setup_account_client(private_key, account_contract_address)" ) invocation = None pagination = self.execution_config.pagination if pagination: index = 0 while index < len(pair_ids): pair_ids_subset = pair_ids[index : index + pagination] invocation = await self.oracle.set_checkpoints.invoke( [ Asset(DataTypes.SPOT, pair_id, None).serialize() for pair_id in pair_ids_subset ], aggregation_mode.serialize(), max_fee=self.execution_config.max_fee, callback=self.track_nonce, ) index += pagination logger.info( "Set spot checkpoints for %d pair IDs with transaction %s", len(pair_ids_subset), hex(invocation.hash), ) else: invocation = await self.oracle.set_checkpoints.invoke( [ Asset(DataTypes.SPOT, pair_id, None).serialize() for pair_id in pair_ids ], aggregation_mode.serialize(), max_fee=self.execution_config.max_fee, callback=self.track_nonce, ) return invocation
[docs] async def get_latest_checkpoint( self, pair_id: str | int, data_type: DataTypes, aggregation_mode: AggregationMode = AggregationMode.MEDIAN, expiration_timestamp: Optional[UnixTimestamp] = None, ) -> Checkpoint: if expiration_timestamp is not None and data_type == DataTypes.SPOT: raise ValueError("expiration_timestamp for SPOT should be None.") (response,) = await self.oracle.functions["get_latest_checkpoint"].call( Asset(data_type, pair_id, expiration_timestamp).serialize(), aggregation_mode.serialize(), ) return Checkpoint( # type: ignore[no-any-return] timestamp=response["timestamp"], value=response["value"], aggregation_mode=AggregationMode(response["aggregation_mode"].variant), num_sources_aggregated=response["num_sources_aggregated"], )
[docs] async def get_last_checkpoint_before( self, pair_id: str | int, data_type: DataTypes, timestamp: UnixTimestamp, aggregation_mode: AggregationMode = AggregationMode.MEDIAN, expiration_timestamp: Optional[UnixTimestamp] = None, ) -> Checkpoint: if expiration_timestamp is not None and data_type == DataTypes.SPOT: raise ValueError("expiration_timestamp for SPOT should be None.") (response,) = await self.oracle.functions["get_last_checkpoint_before"].call( Asset(data_type, pair_id, expiration_timestamp).serialize(), timestamp, aggregation_mode.serialize(), ) return Checkpoint( # type: ignore[no-any-return] timestamp=response["timestamp"], value=response["value"], aggregation_mode=AggregationMode(response["aggregation_mode"].variant), num_sources_aggregated=response["num_sources_aggregated"], )
[docs] async def get_admin_address(self) -> Address: """ Return the admin address of the Oracle contract. """ (response,) = await self.oracle.functions["get_admin_address"].call() return response # type: ignore[no-any-return]
[docs] async def update_oracle( self, implementation_hash: int, ) -> InvokeResult: """ Update the Oracle contract to a new implementation. :param implementation_hash: New implementation hash :return: InvokeResult """ if not self.is_user_client: raise AttributeError( "Must set account. " "You may do this by invoking " "self._setup_account_client(private_key, account_contract_address)" ) invocation = await self.oracle.functions["upgrade"].invoke( implementation_hash, max_fee=self.execution_config.max_fee, ) return invocation
[docs] async def get_time_since_last_published_spot( self, pair: Pair, publisher: str, block_id: Optional[BlockId] = "latest", ) -> int: """ Get the time since the last published spot entry by a publisher for a given pair. Will return a large number if no entry is found. :param pair: Pair :param publisher: Publisher name e.g "PRAGMA" :param block_id: Block number or Block Tag :return: Time since last published entry """ all_entries = await self.get_spot_entries(pair.id, block_id=block_id) entries = [ entry for entry in all_entries if entry.base.publisher == str_to_felt(publisher) ] if len(entries) == 0: return 1000000000 # arbitrary large number max_timestamp = max(entry.base.timestamp for entry in entries) diff = int(time.time()) - max_timestamp return diff