"""A minimal set of utilities used by KodeAgent.
This module will be copied along with code for CodeAgent, so keep it minimum.
"""
import base64
import inspect
import logging
import mimetypes
import os
import re
from collections.abc import Callable
from typing import Any
import litellm
import pydantic as pyd
import requests
from tenacity import (
AsyncRetrying,
RetryError,
before_sleep_log,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from .usage_tracker import UsageMetrics
# Maps Python types to the type-name strings used when describing values.
# NOTE(review): the names look like JSON Schema type names — confirm against the caller.
# `Any` is mapped to 'string'.
DATA_TYPES = {
    int: 'integer',
    float: 'number',
    str: 'string',
    bool: 'boolean',
    list: 'array',
    dict: 'object',
    Any: 'string',
}
# Default value of the `max_retries` parameter of `call_llm`
DEFAULT_MAX_LLM_RETRIES = 3
# Max characters of file content to inline in a user message
MAX_FILE_CONTENT_LENGTH = 5_000
# Names of loggers that are restricted to ERROR and above and detached from the
# root logger by the loop below (this runs at import time).
LOGGERS_TO_SUPPRESS = [
    'asyncio',
    'cookie_store',
    'e2b',
    'e2b_code_interpreter',
    'hpack',
    'httpx',
    'httpcore',
    'hyper_util',
    'h2',
    'langfuse',
    'LiteLLM',
    'litellm',
    'openai',
    'pdfminer',
    'primp',
    'reqwest',
    'rquest',
    'rustls',
    'urllib3',
    'urllib3.connectionpool',
]
for _lg in LOGGERS_TO_SUPPRESS:
    logger_obj = logging.getLogger(_lg)
    # Drop everything below ERROR for this logger
    logger_obj.setLevel(logging.ERROR)
    # Prevent these logs from propagating to the root logger
    logger_obj.propagate = False
# Capture warnings from the warnings module (optional, helps centralize output)
if hasattr(logging, 'captureWarnings'):
    logging.captureWarnings(True)
[docs]
def get_logger(name: str | None = 'KodeAgent') -> logging.Logger:
"""Get a logger for KodeAgent.
Returns:
A logger instance.
"""
logging.basicConfig(
level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logging.getLogger('LiteLLM').setLevel(logging.WARNING)
logging.getLogger('langfuse').disabled = True
return logging.getLogger(name)
# Module-level logger used by the helpers in this file. It is registered under the
# fixed name 'KodeAgent' (the same name `get_logger` uses by default), not `__name__`.
logger = logging.getLogger('KodeAgent')
[docs]
def read_prompt(filename: str) -> str:
    """Load a prompt file that lives in the `prompts` directory next to this module.

    Args:
        filename: Name of the prompt file to read.

    Returns:
        The UTF-8 decoded text of the prompt file.

    Raises:
        FileNotFoundError: If the prompt file does not exist.
        RuntimeError: If any other error occurs while reading the file.
    """
    prompts_dir = os.path.join(os.path.dirname(__file__), 'prompts')
    prompt_path = os.path.join(prompts_dir, filename)
    try:
        with open(prompt_path, encoding='utf-8') as handle:
            prompt_text = handle.read()
    except FileNotFoundError as fnfe:
        raise FileNotFoundError(
            f'Prompt file `{filename}` not found in the prompts directory: {prompt_path}'
        ) from fnfe
    except Exception as e:
        raise RuntimeError(f'Error reading prompt file `{filename}`: {e}') from e
    return prompt_text
[docs]
def is_it_url(path: str) -> bool:
    """Tell whether `path` is an HTTP(S) URL.

    The check is a case-sensitive prefix test; no further URL validation is done.

    Args:
        path: The path or URL to inspect.

    Returns:
        `True` if `path` starts with `http://` or `https://`; `False` otherwise.
    """
    return path.startswith('http://') or path.startswith('https://')
[docs]
def detect_file_type(url: str) -> str:
    """Identify the content/MIME type of the file pointed to by a URL.

    Detection happens in two steps:

    1. A HEAD request is sent. If the response carries a `Content-Disposition` header
       with a `filename=` part, the MIME type is guessed from that file name.
    2. Otherwise (or if the file name's type cannot be guessed), a streamed GET request
       is sent and its `Content-Type` header is returned, unless it is JSON.

    Args:
        url: The URL to the file.

    Returns:
        The detected MIME type (e.g. `image/png`), or the literal string
        `Unknown file type` when nothing could be determined or the request failed.
    """
    try:
        # Step 1: Try HEAD request to get Content-Disposition
        response = requests.head(url, allow_redirects=True, timeout=15)
        content_disposition = response.headers.get('Content-Disposition')

        if content_disposition and 'filename=' in content_disposition:
            # The value may be quoted and may be followed by further parameters,
            # e.g. `attachment; filename="report.pdf"; size=42`.
            file_name = content_disposition.split('filename=')[1]
            file_name = file_name.split(';')[0].strip().strip('"\'')
            # Map the name to a MIME type so that the result matches the documented
            # contract (and works with `is_image_file`), instead of a bare extension.
            guessed_type, _ = mimetypes.guess_type(file_name)
            if guessed_type:
                return guessed_type

        # Step 2: If HEAD didn't give useful info, send GET request for more details.
        # Only the headers are needed; the context manager closes the streamed
        # response so the connection is released without downloading the body.
        with requests.get(url, stream=True, timeout=20) as response:
            content_type = response.headers.get('Content-Type')

        if content_type:
            # Compare without parameters so `application/json; charset=utf-8` is
            # also excluded (avoid false positives).
            media_type = content_type.split(';')[0].strip().lower()
            if media_type != 'application/json':
                return content_type

        return 'Unknown file type'
    except requests.RequestException as e:
        logger.error('Error detecting file type: %s', str(e))
        return 'Unknown file type'
[docs]
def is_image_file(file_type: str) -> bool:
    """Tell whether a MIME type denotes an image.

    Args:
        file_type: The file/content type, e.g. `image/png`.

    Returns:
        `True` if `file_type` starts with `image/`; `False` otherwise.
    """
    image_prefix = 'image/'
    return file_type.startswith(image_prefix)
[docs]
async def _record_llm_usage(response: Any, usage_tracker: Any | None, component_name: str) -> None:
    """Log the token usage of an LLM response and forward it to the usage tracker.

    This is best-effort bookkeeping: any failure is logged as a warning and swallowed,
    so that a valid LLM response is never discarded (or re-requested) because its
    usage data could not be read or recorded.
    """
    try:
        token_usage = {
            'cost': response._hidden_params.get('response_cost', 0.0),
            'prompt_tokens': response.usage.get('prompt_tokens', 0),
            'completion_tokens': response.usage.get('completion_tokens', 0),
            'total_tokens': response.usage.get('total_tokens', 0),
        }
    except Exception as e:
        logger.warning('Failed to read token usage from the LLM response: %s', str(e))
        return

    logger.info(token_usage)

    # Record usage if tracker provided
    if usage_tracker:
        try:
            metrics = UsageMetrics(
                prompt_tokens=token_usage['prompt_tokens'],
                completion_tokens=token_usage['completion_tokens'],
                total_tokens=token_usage['total_tokens'],
                cost=token_usage['cost'] or 0.0,
            )
            await usage_tracker.record_usage(component_name, metrics)
        except Exception as e:
            logger.warning('Failed to record usage: %s', str(e))


async def call_llm(
    model_name: str,
    litellm_params: dict,
    messages: list[dict],
    response_format: type[pyd.BaseModel] | None = None,
    trace_id: str | None = None,
    max_retries: int = DEFAULT_MAX_LLM_RETRIES,
    usage_tracker: Any | None = None,
    component_name: str = 'unknown',
) -> str | None:
    """Call the LLM with the given parameters and response format.

    The call is retried with randomized exponential back-off on any exception,
    including an empty choices list or an empty response body. Token usage is logged
    and, if a tracker is given, recorded; failures in that bookkeeping do not affect
    the returned response.

    Args:
        model_name: The name of the LLM model to use.
        litellm_params: Dictionary of parameters to pass to litellm. These override the
            defaults set here (e.g. the 30-second `timeout`).
        messages: List of message dictionaries.
        response_format: Optional pydantic model for structured output.
        trace_id: Optional trace ID for observability.
        max_retries: Maximum number of attempts for the LLM call.
        usage_tracker: Optional UsageTracker instance to record usage.
        component_name: Name of the component making the call (for tracking).

    Returns:
        The LLM response as string.

    Raises:
        RetryError: If the LLM call fails after maximum retries.
        ValueError: If an exception other than `RetryError` escapes the retry loop.
    """
    params = {'model': model_name, 'messages': messages}
    if response_format:
        params['response_format'] = response_format

    # Add a timeout to prevent indefinite hangs
    if 'timeout' not in litellm_params:
        params['timeout'] = 30  # seconds
    params.update(litellm_params)

    try:
        # Use AsyncRetrying to handle retries in a non-blocking way
        async for attempt in AsyncRetrying(
            stop=stop_after_attempt(max_retries),
            wait=wait_random_exponential(multiplier=1, max=60),
            retry=retry_if_exception_type(Exception),
            before_sleep=before_sleep_log(logger, logging.WARNING),
        ):
            with attempt:
                # Use the asynchronous litellm call
                response = await litellm.acompletion(
                    **params,
                    metadata={
                        'trace_id': str(trace_id) if trace_id else None,
                        'trace_name': 'kodeagent',
                        'generation_name': 'kodeagent-generation',
                        'generation_metadata': {
                            'response_format': (
                                response_format.__name__ if response_format else 'None'
                            ),
                            'attempt': attempt.retry_state.attempt_number,
                        },
                        'tags': [model_name],
                    },
                )

                # Check for empty choices list (raising here triggers a retry)
                if not response.choices:
                    raise ValueError('LLM returned an empty choices list.')

                # Check for empty content (raising here triggers a retry)
                response_content = response.choices[0].message.content
                if not response_content or not response_content.strip():
                    raise ValueError('LLM returned an empty or invalid response body.')

                # Bookkeeping must never raise inside the attempt; otherwise a
                # successful (and billed) call would be retried.
                await _record_llm_usage(response, usage_tracker, component_name)

                return response_content
    except RetryError:
        raise
    except Exception as e:
        # `logger.exception` already attaches the traceback
        logger.exception('LLM call failed after repeated attempts: %s', str(e))
        raise ValueError('Failed to get a valid response from LLM after multiple retries.') from e
[docs]
def make_user_message(text_content: str, files: list[str] | None = None) -> list[dict[str, Any]]:
"""Create a single user message to be sent to LiteLLM.
Args:
text_content: The text content of the message.
files: An optional list of file paths or URLs, which can include images
or other file types.
Returns:
A list of dict items representing the messages.
"""
content: list[dict[str, Any]] = [{'type': 'text', 'text': str(text_content)}]
message: list[dict[str, Any]] = [{'role': 'user'}]
if files:
for item in files:
is_image = False
if is_it_url(item):
if any(
ext in item.lower()
for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']
) or is_image_file(detect_file_type(item)):
is_image = True
elif os.path.isfile(item):
try:
mime_type, _ = mimetypes.guess_type(item)
if mime_type and 'image' in mime_type:
is_image = True
except Exception:
logger.error(
'Error guessing MIME type for local file %s...will ignore it',
item,
exc_info=True,
)
# If an error occurs, treat it as not an image to continue processing
is_image = False
if is_image:
if is_it_url(item):
content.append({'type': 'image_url', 'image_url': {'url': item}})
elif os.path.isfile(item):
try:
with open(item, 'rb') as img_file:
encoded_image = base64.b64encode(img_file.read()).decode('utf-8')
try:
mime_type, _ = mimetypes.guess_type(item)
except Exception:
logger.warning(
'Could not guess MIME type, defaulting to octet-stream',
exc_info=True,
)
mime_type = 'application/octet-stream'
mime_type = mime_type if mime_type else 'application/octet-stream'
content.append(
{
'type': 'image_url',
'image_url': {'url': f'data:{mime_type};base64,{encoded_image}'},
}
)
except FileNotFoundError:
logger.error('Image file not found: %s...will ignore it', item)
except Exception as e:
logger.error(
'Error processing local image %s: %s...will ignore it', item, e
)
else:
logger.error('Invalid image file path or URL: %s...will ignore it', item)
else: # Handle as a general file or URL (not an image)
if is_it_url(item):
content.append({'type': 'text', 'text': f'File URL: {item}'})
elif os.path.isfile(item):
# Attempt to read the file as UTF-8 text.
# UnicodeDecodeError is the reliable signal that a file is binary —
# more robust than mimetypes.guess_type(), which is extension-based,
# unreliable on Windows (e.g. .csv often has no registered MIME type),
# and returns None for any unknown extension.
try:
with open(item, encoding='utf-8') as f:
file_content = f.read()
if len(file_content) <= MAX_FILE_CONTENT_LENGTH:
content.append(
{
'type': 'text',
'text': f'File {item} content:\n{file_content}',
}
)
else:
logger.warning(
'File `%s` content (%d chars) exceeds the %d-char threshold '
'— including path reference only.',
item,
len(file_content),
MAX_FILE_CONTENT_LENGTH,
)
content.append(
{
'type': 'text',
'text': (
f'Input file: {item} (file too large to include inline;'
' read it directly from the path)'
),
}
)
except UnicodeDecodeError:
# Binary file — include only the path reference
logger.debug(
'File `%s` is not valid UTF-8 text (binary?) — including path only.',
item,
)
content.append({'type': 'text', 'text': f'Input file: {item}'})
except Exception:
logger.error(
'Error reading file `%s`...will fallback to path only',
item,
exc_info=True,
)
content.append({'type': 'text', 'text': f'Input file: {item}'})
else:
logger.error('Invalid file path or URL: %s...will ignore it', item)
message[0]['content'] = content
return message
[docs]
def combine_user_messages(messages: list) -> list:
    """Merge runs of consecutive user messages into one message each.

    The content of every user message in the result is a list; content that is not
    already a list is wrapped in a one-element list. Non-user messages are passed
    through unchanged (the same objects).

    Args:
        messages: The list of message dicts to process.

    Returns:
        A new list of messages with combined user messages.
    """

    def _as_list(value):
        # Wrap anything that is not already a list
        return value if isinstance(value, list) else [value]

    merged = []
    for message in messages:
        if message.get('role') != 'user':
            merged.append(message)
            continue

        items = _as_list(message.get('content', []))
        previous = merged[-1] if merged else None
        if previous is not None and previous.get('role') == 'user':
            # Extend the preceding user message with a fresh list (inputs untouched)
            previous['content'] = _as_list(previous['content']) + items
        else:
            merged.append({'role': 'user', 'content': items})
    return merged
[docs]
def clean_json_string(json_str: str) -> str:
    """Clean and repair common JSON formatting issues from LLM responses.

    The clean-up strips Markdown code fences, keeps only the span from the first `{`
    to the last `}` (when both exist in that order), removes trailing quotes and
    whitespace, and un-escapes `\\'`.

    Args:
        json_str: Potentially malformed JSON string.

    Returns:
        The cleaned JSON string. Empty or non-string input is returned unchanged.
    """
    if not (json_str and isinstance(json_str, str)):
        return json_str

    # Remove Markdown code fences (opening ```json first, then any bare ```)
    text = json_str
    for fence_pattern in (r'^```json\s*', r'^```\s*'):
        text = re.sub(fence_pattern, '', text, flags=re.MULTILINE)
    text = text.strip()

    # Narrow down to the outermost {...} span, if there is one
    first_brace = text.find('{')
    last_brace = text.rfind('}')
    if 0 <= first_brace < last_brace:
        text = text[first_brace : last_brace + 1]

    # Remove trailing quotes/whitespace
    text = re.sub(r'[\'"\s]*$', '', text)

    # LLMs sometimes over-escape single quotes, which JSON does not allow
    text = text.replace("\\'", "'")
    return text.strip()
[docs]
# A line consisting solely of one of these headers ends a Google-style `Args:` block.
# `Args`/`Parameters`-like headers are deliberately absent, so entries of a following
# `Keyword Args:` block are still collected.
_DOC_SECTION_END_RE = re.compile(
    r'(?:Returns?|Yields?|Raises?|Examples?|Notes?|References|See Also|Todo|'
    r'Warnings?|Warns|Attributes)\s*:',
    re.IGNORECASE,
)


def parse_param_descriptions(doc: str) -> dict[str, str]:
    """Extract per-parameter descriptions from a docstring.

    Supports Google-style (Args:) and Sphinx-style (:param name:) formats. Only the
    first line of each Google-style description is captured. Sphinx-style entries are
    used only when no Google-style parameter was found.

    Args:
        doc: The docstring to parse.

    Returns:
        A dictionary mapping parameter names to their descriptions.
    """
    param_docs: dict[str, str] = {}
    if not doc:
        return param_docs

    # Google-style: Args: / Parameters: section
    args_section = re.search(r'(?:Args|Parameters):\s*(.*)', doc, re.DOTALL | re.IGNORECASE)
    if args_section:
        for line in args_section.group(1).split('\n'):
            line = line.strip()
            if not line:
                continue
            # The capture above runs to the end of the docstring. Stop at the next
            # section header so that `Returns:` / `Raises:` and their entries
            # (e.g. `ValueError: ...`) are not reported as parameters.
            if _DOC_SECTION_END_RE.fullmatch(line):
                break
            match = re.match(r'(\w+)\s*(?:\(.*?\))?\s*:\s*(.*)', line)
            if match:
                param_docs[match.group(1)] = match.group(2).strip()

    # Sphinx-style: :param name: description
    if not param_docs:
        for m in re.finditer(r':param\s+(\w+):\s*(.*)', doc):
            param_docs[m.group(1)] = m.group(2).strip()

    return param_docs
VALID_ROLES = {'system', 'user', 'assistant', 'tool'}
[docs]
def validate_chat_history(history: list[dict], tool_names: set[str] | None = None) -> None:
"""Validate that a provided chat history is OpenAI-compliant.
Performs stringent structural and semantic checks on each message. Raises
``ValueError`` with a precise description on the first problem found.
Args:
history: The chat history to validate. Must be a non-empty ``list[dict]``.
tool_names: Optional set of tool names registered on the agent. When
provided, tool call names not in this set emit a ``logger.warning``.
Raises:
ValueError: If the history fails any structural or compliance check.
"""
if not isinstance(history, list):
raise ValueError(f'chat_history must be a list[dict], got {type(history).__name__}.')
if not history:
raise ValueError('chat_history must not be empty.')
system_seen_at: int | None = None
requested_tool_call_ids: set[str] = set()
for idx, msg in enumerate(history):
if not isinstance(msg, dict):
raise ValueError(
f'chat_history[{idx}]: each message must be a dict, got {type(msg).__name__}.'
)
role = msg.get('role')
if role not in VALID_ROLES:
raise ValueError(
f"chat_history[{idx}]: 'role' must be one of {sorted(VALID_ROLES)}, got {role!r}."
)
# ── system message rules ──────────────────────────────────────────
if role == 'system':
if idx != 0:
raise ValueError(
f'chat_history[{idx}]: system message must be at index 0, not at index {idx}.'
)
system_seen_at = idx
if msg.get('content') is None:
raise ValueError(
f"chat_history[{idx}]: system message must have a non-None 'content'."
)
# ── user message rules ────────────────────────────────────────────
elif role == 'user':
if msg.get('content') is None:
raise ValueError(
f"chat_history[{idx}]: user message must have a non-None 'content'."
)
# ── assistant message rules ───────────────────────────────────────
elif role == 'assistant':
tool_calls = msg.get('tool_calls')
content = msg.get('content')
if tool_calls is None and content is None:
raise ValueError(
f'chat_history[{idx}]: assistant message must have either '
f"'content' or 'tool_calls'."
)
if tool_calls is not None:
if not isinstance(tool_calls, list):
raise ValueError(
f"chat_history[{idx}]: 'tool_calls' must be a list, "
f'got {type(tool_calls).__name__}.'
)
for tc_idx, tc in enumerate(tool_calls):
if not isinstance(tc, dict):
raise ValueError(
f'chat_history[{idx}].tool_calls[{tc_idx}]: '
f'each tool call must be a dict, got {type(tc).__name__}.'
)
tc_id = tc.get('id')
if not tc_id or not isinstance(tc_id, str):
raise ValueError(
f'chat_history[{idx}].tool_calls[{tc_idx}]: '
f"missing or empty 'id' (string required)."
)
# Track for subsequent tool result validation
requested_tool_call_ids.add(tc_id)
if tc.get('type') != 'function':
raise ValueError(
f'chat_history[{idx}].tool_calls[{tc_idx}]: '
f"'type' must be 'function', got {tc.get('type')!r}."
)
fn = tc.get('function')
if not isinstance(fn, dict):
raise ValueError(
f'chat_history[{idx}].tool_calls[{tc_idx}]: '
f"'function' must be a dict, got {type(fn).__name__}."
)
fn_name = fn.get('name')
if not fn_name or not isinstance(fn_name, str):
raise ValueError(
f'chat_history[{idx}].tool_calls[{tc_idx}]: '
f"'function.name' must be a non-empty string."
)
if not isinstance(fn.get('arguments'), str):
raise ValueError(
f'chat_history[{idx}].tool_calls[{tc_idx}]: '
f"'function.arguments' must be a string (JSON-encoded)."
)
# Warn about unknown tool names
if tool_names is not None and fn_name not in tool_names:
logger.warning(
'chat_history[%d].tool_calls[%d]: tool name %r not registered '
'on this agent (known: %s). The LLM will not re-call it.',
idx,
tc_idx,
fn_name,
sorted(tool_names),
)
# ── tool result message rules ─────────────────────────────────────
elif role == 'tool':
tc_id = msg.get('tool_call_id')
if not tc_id or not isinstance(tc_id, str):
raise ValueError(
f"chat_history[{idx}]: tool message missing or empty 'tool_call_id' "
f'(string required).'
)
name = msg.get('name')
if not name or not isinstance(name, str):
raise ValueError(
f"chat_history[{idx}]: tool message missing or empty 'name' (string required)."
)
if not isinstance(msg.get('content'), str):
raise ValueError(
f"chat_history[{idx}]: tool message 'content' must be a string, "
f'got {type(msg.get("content")).__name__}.'
)
# Ensure this result refers to a tool call that was actually requested
if tc_id not in requested_tool_call_ids:
raise ValueError(
f'chat_history[{idx}]: tool message refers to unknown tool_call_id {tc_id!r}.'
)
# ── pending tool call check ───────────────────────────────────────────
# Walk backwards; if the last assistant message has tool_calls, verify
# that every id appears in a subsequent tool message.
tool_result_ids: set[str] = set()
for msg in reversed(history):
role = msg.get('role')
if role == 'tool':
tool_result_ids.add(msg.get('tool_call_id', ''))
elif role == 'assistant':
pending = msg.get('tool_calls')
if pending:
unresolved = [
tc['id']
for tc in pending
if isinstance(tc, dict) and tc.get('id') not in tool_result_ids
]
if unresolved:
raise ValueError(
f'chat_history ends with unresolved tool call(s): '
f'{unresolved}. Each tool_call must have a corresponding '
f"'tool' message before appending new messages."
)
# Stop at the first assistant message from the end
break