"""Use Claude to analyse and describe a given ARC task"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_describe.ipynb.

# %% auto 0
__all__ = ['sp_direct', 'sp_indiv', 'sp_merge', 'Description', 'ShapeExtractor', 'DescriptionGenerator']

# %% ../nbs/02_describe.ipynb 5
from .task import ArcTask, ArcPair
from .ocm import Color
from .utils import parse_from_xml
from claudette import *
from fastcore.utils import *
from fastcore.meta import *
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock, AsyncAnthropicVertex
from anthropic.types import Usage
import asyncio
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
from scipy import ndimage

# %% ../nbs/02_describe.ipynb 8
from toolslm.funccall import get_schema

# %% ../nbs/02_describe.ipynb 9
@patch
async def __call__(self:AsyncChat,
        pr=None,  # Prompt / message
        temp=0, # Temperature
        maxtok=4096, # Maximum tokens
        stream=False, # Stream response?
        prefill='', # Optional prefill to pass to Claude as start of its response
        **kw):
    "Add `pr` to the dialog, call the underlying client (advertising `self.tools`), and extend the history with `mk_toolres` output"
    await self._append_pr(pr)
    # Each registered tool is converted to a schema and sent with the request.
    if self.tools:
        kw['tools'] = [get_schema(tool) for tool in self.tools]
    res = await self.c(
        self.h,
        stream=stream,
        prefill=prefill,
        sp=self.sp,
        temp=temp,
        maxtok=maxtok,
        **kw,
    )
    # In streaming mode the wrapped stream is returned and the history is not extended here.
    if stream:
        return self._stream(res)
    # Tools are resolved against `self.tools` itself (passed as the namespace).
    self.h += mk_toolres(self.c.result, ns=self.tools)
    return res

# %% ../nbs/02_describe.ipynb 11
# Default system prompt for the "direct" strategy: it is used by `_describe_direct`
# when no `sp` is supplied and as the fallback for `DescriptionGenerator.direct_sp`.
# It asks for the analysis inside <reasoning> tags and the summary inside
# <description> tags; `Description.d` parses the `description` tag from the response.
# A trailing backslash inside the literal joins that line with the next one.
sp_direct = """\
You are an expert at solving visual IQ puzzles involving transformations between input and output colored grids. \
Your task is to analyze images of puzzles and provide concise, accurate solutions. To solve a given puzzle, follow these steps:

1. INITIAL ANALYSIS
   a) Analyze the input and output grids carefully
   b) Note grid dimensions and any patterns in how they change
   c) Identify if there's a consistent background color
   d) Check if colors maintain consistent meanings across examples

2. ELEMENT IDENTIFICATION
   a) List all elements present in both input and output grids, counting them explicitly
   b) Note whether elements maintain their properties (color, size, shape) across examples
   c) Identify any hierarchical relationships between elements

3. TRANSFORMATION ANALYSIS
   a) Compare input and output grids side by side
   b) Note specific transformations for each element
   c) Analyze whether transformations apply:
      - Globally to the entire grid
      - To individual objects/regions
      - To specific color groups
      - To relationships between objects

4. PATTERN RECOGNITION
   a) Check for symmetry properties:
      - Rotational invariance
      - Mirror symmetry
      - Translation invariance
      - Scale invariance
   b) Look for pattern periodicity or repetition
   c) Analyze edge cases; e.g.:
      - Grid boundary interactions
      - Overlapping objects/patterns

5. CORE KNOWLEDGE PRIORS
   Bear in mind that this class of puzzles are solvable assuming only basic cognitive principles:
   a) Objectness:
      - Object cohesion
      - Object persistence
      - Influence via contact
   b) Goal-directedness (start and end states)
   c) Numbers and Counting (typically <10)
   d) Basic Geometry and Topology:
      - Lines, shapes, and basic geometric properties
      - Symmetries (rotational, mirror, translational)
      - Transformations (rotation, translation, scaling)
      - Spatial relationships and connectivity
      - Pattern periodicity and repetition
   e) Color consistency and meaning
   f) Hierarchical relationships

6. GRID PROPERTIES ANALYSIS
   a) Analyze dimension relationships:
      - Whether input/output dimensions are preserved
      - What determines output dimensions if different
   b) Consider how grid size affects transformation rules
   c) Look for patterns in provided dimension metadata

7. RULE FORMULATION
   a) Develop a general transformation rule
   b) Ensure the rule is:
      - Abstract and applicable to all examples
      - Deterministic (same input always produces same output)
      - Algorithmically implementable
   c) Consider alternative explanations and rule them out
   d) Test rule against edge cases

8. VALIDATION
   a) Double-check rule against all provided examples
   b) Verify rule consistency across different grid sizes
   c) Confirm rule handles all identified edge cases
   d) Refine rule if necessary

Present your thought process in <reasoning> tags. \
This is where you can break down your observations, reasoning, and alternative explanations in detail. \
It's OK for this section to be quite long and to include explicit counting and detailed analysis.

After your analysis, provide a concise solution summary in <description> tags. This summary should:
- Be no more than 4-5 sentences long
- Clearly describe first (in generality) the properties/objects in the input grids
- Then describe properties of the output grids and how an output grid is constructed from its input grid
- Avoid instance-specific descriptions or if-else statements
- Capture all key aspects of the transformation
- Use precise, unambiguous language

Your goal is to provide a clear, concise, and accurate description that captures the essence of the puzzle's \
transformation rule while being general enough to work across all examples. Remember to close xml tags.
"""

# %% ../nbs/02_describe.ipynb 14
def _shape_table(pairs: List[ArcPair]  # List of training example pairs
                ) -> str:              # string containing a table of grid shapes
    "Render the input/output grid shapes of `pairs` as a pipe-delimited text table"
    body_lines = []
    for grid_in, grid_out in pairs:
        # Shapes are stringified and left-aligned to fixed widths (10 and 11 characters).
        body_lines.append(f"| {grid_in.shape!s:<10} | {grid_out.shape!s:<11} |")
    table_head = "| Input Shape | Output Shape |\n|------------|-------------|\n"
    return table_head + "\n".join(body_lines)

# %% ../nbs/02_describe.ipynb 19
@dataclass
class Description:
    "One generated description of an ARC task, together with the chats that produced it."
    content: str            # The full description response (including reasoning)
    chats: List[AsyncChat]  # Store all chats used to generate this description
    method: str             # 'direct' or 'indirect'

    @property
    def d(self) -> str:
        "The `description` tag of `content`, as returned by `parse_from_xml`"
        full_response = self.content
        return parse_from_xml(full_response, 'description')

    @property
    def usage(self) -> Usage:
        "Token usage accumulated over every chat in `chats`."
        # Start from an all-zero Usage and add each chat's `use` in order.
        total = Usage(input_tokens=0, output_tokens=0)
        for chat in self.chats:
            total = total + chat.use
        return total

    @property
    def cost(self) -> float:
        "Sum of the `cost` of every chat in `chats`."
        total = 0
        for chat in self.chats:
            total = total + chat.cost
        return total

# %% ../nbs/02_describe.ipynb 21
def _create_client(client_type,
                   client_kwargs
                  ) -> Union[AsyncAnthropic, AsyncAnthropicBedrock, AsyncAnthropicVertex]:
    """Create appropriate async client based on configuration.

    `client_type` selects the client class: "bedrock" -> `AsyncAnthropicBedrock`,
    "vertex" -> `AsyncAnthropicVertex`, anything else -> `AsyncAnthropic`.
    `client_kwargs` (a dict or None) is forwarded to the chosen constructor and is
    never modified.

    For the standard client the prompt-caching beta header is set by default.
    Any `default_headers` supplied in `client_kwargs` are merged on top of it
    (caller-supplied values win) instead of clashing with the built-in keyword.
    """
    # Work on a copy: callers reuse the same kwargs dict for many clients.
    kwargs = dict(client_kwargs or {})
    if client_type == "bedrock":
        return AsyncAnthropicBedrock(**kwargs)
    if client_type == "vertex":
        return AsyncAnthropicVertex(**kwargs)
    # default to standard Anthropic
    headers = {'anthropic-beta': 'prompt-caching-2024-07-31'}
    headers.update(kwargs.pop('default_headers', None) or {})
    return AsyncAnthropic(default_headers=headers, **kwargs)

# %% ../nbs/02_describe.ipynb 22
def _create_chat(model, client, sp: str, tools: Optional[list] = None) -> AsyncChat:
    "Build an `AsyncChat` with system prompt `sp` and `tools`, backed by an `AsyncClient` for `model` wrapping `client`."
    return AsyncChat(cli=AsyncClient(model, client), sp=sp, tools=tools)

# %% ../nbs/02_describe.ipynb 23
async def _describe_direct(
    task: ArcTask | str,                        # Either an ArcTask object or a task ID string
    model: str = 'claude-3-5-sonnet-20241022',  # Model identifier (defaults to Sonnet 3.5)
    client_type: str = 'anthropic',             # 'anthropic', 'bedrock', or 'vertex'
    client_kwargs: Optional[Dict] = None,       # Optional kwargs for client instantiation
    sp: str | None = None,                      # Custom system prompt (if None, uses `sp_direct`)
    temp: float = 0.0,                          # Sampling temperature for generation
    prefill: str = '<reasoning>',               # Text to prefill the assistant's response with
    **kwargs                                    # Additional arguments passed to AsyncChat.__call__
) -> Description:                               # Container holding description and the chat object
    "Generate a description of an ARC task from all examples at once"

    if sp is None: sp = sp_direct
    if isinstance(task, str): task = ArcTask(task)
    # `None` replaces the former shared `{}` default (mutable default argument).
    if client_kwargs is None: client_kwargs = {}

    # Set up chat and get description
    client = _create_client(client_type, client_kwargs)
    chat = _create_chat(model, client, sp)

    # Text part of the prompt: a table of grid shapes plus the colors that occur
    # anywhere in the training inputs and outputs respectively.
    pr = f"The example grids in the image have the following shapes:\n{_shape_table(task.train)}"
    in_cols = np.unique(np.hstack([p.input.data.flatten() for p in task.train]))
    out_cols = np.unique(np.hstack([p.output.data.flatten() for p in task.train]))
    pr += f"\nThe colors present in the input grids are: {', '.join(Color.colors[i] for i in in_cols)}\n"
    pr += f"\nThe colors present in the output grids are: {', '.join(Color.colors[i] for i in out_cols)}"

    # The message is the task plot (base64) followed by the text prompt.
    r = await chat([task.plot(to_base64=True), pr],
                   prefill=prefill,
                   temp=temp,
                   **kwargs)

    return Description(
        content=r.content[0].text,
        chats=[chat],
        method='direct'
    )

# %% ../nbs/02_describe.ipynb 26
# Default system prompt for analysing a single input/output pair in the "indirect"
# strategy: used by `_describe_indirect` when no `sp` is supplied and as the fallback
# for `DescriptionGenerator.indirect_sp`.
# A trailing backslash inside the literal joins that line with the next one.
sp_indiv = """\
You are an expert puzzle analyst tasked with deciphering complex visual transformation puzzles. \
You analyze and describe the patterns and distinct shapes contained in input and output grids \
and judge what transformation maps input grids to output grids. 

Analysis Process:
1. INITIAL VISUAL ANALYSIS
   First, analyze the grids visually without tools:
   - Note obvious patterns, shapes, and colors
   - Form initial hypotheses about the transformation
   - Consider whether the transformation appears to be:
     * Global (affecting the entire grid)
     * Local (affecting specific shapes/regions)
     * Pattern-based (involving repetition or rules)
   - Think about why certain changes might occur

2. HYPOTHESIS REFINEMENT
   If your initial analysis leaves uncertainties, consider whether the shape extraction tool would help:
   - For complex or irregular shapes that might be related
   - To verify suspected rotations, reflections, repetitions, or scaling
   - To precisely analyze spatial relationships
   - To confirm pattern hypotheses
   Note that the shape extraction tool can only extract connected regions of a single color; often a "shape" in an ARC task is a distinct multi-colored region.

Use the shape extraction tool judiciously to test specific hypotheses rather than as a first resort. Avoid using it when:
- Shapes are simple and easily describable
- The pattern is clearly global
- The transformation is obvious visually
- Colors appear to be background or noise rather than meaningful shapes

Remember, these puzzles are designed to be solvable using only "Core Knowledge priors":
1. Objectness priors (object cohesion, persistence, and influence via contact)
2. Goal-directedness prior (conceptualize as intentional processes with start/end states)
3. Numbers and Counting priors (quantities typically <10)
4. Basic Geometry and Topology priors (lines, shapes, symmetries, rotations, translations, scaling, spatial relationships)

Present your analysis in this format:

<initial_analysis>
Based on visual inspection:
- Overview of grid properties and obvious patterns
- Initial transformation hypothesis
- Key uncertainties or questions
- Whether tool analysis would be helpful and why
</initial_analysis>

If tool use is required, begin tool-calling process now. \
Once you have gathered the required information via tool calling, proceed with your analysis as follows:

<detailed_analysis>
Input:
- Dimensions: [x, y]
- Background: [color if applicable]
- Shapes: [list major shapes with properties and positions]
  * Shape relationships: [note any rotational/reflectional/scaling relationships]
  * Spatial relationships: [relative positions, alignments, groupings]
  * Remember shapes can be multi-colored
- Colors: [list with roles]
- Notable patterns: [recurring elements, global patterns]

Output:
- [Same structure as Input]

Transformations:
- Size changes: [grid or shape scaling]
- Shape changes: [rotations, reflections, splits, merges]
- Color changes: [role changes, new colors, color relationships]
- Position changes: [translations, relative position changes]
- Pattern changes: [how global patterns transform]
</detailed_analysis>

<final_hypothesis>
Provide your final transformation rule hypothesis, considering:
- Whether the transformation is local or global
- What properties are preserved
- Why certain changes occur
- How the transformation aligns with core knowledge priors
- Any remaining uncertainties
</final_hypothesis>

Note: It may not be possible to determine the exact relationship with full confidence. \
Make your best guess considering the points above, and note any uncertainties in your hypothesis. \
Remember to close xml tags.
"""

# %% ../nbs/02_describe.ipynb 28
def _pair_prompt(pair: ArcPair,  # A single training example pair from an ARC task
                example_idx: int = 0,  # The index of the training example
               ):
    "Text prompt listing the idx, size and colors of the input and output grid of `pair`"
    grid_in, grid_out = pair
    # Grids are numbered pairwise: example k owns idx 2k (input) and 2k+1 (output).
    in_idx = example_idx * 2
    in_rows, in_cols = grid_in.shape[0], grid_in.shape[1]
    in_colors = ', '.join(Color.colors[c] for c in np.unique(grid_in.data))
    out_idx = in_idx + 1
    out_rows, out_cols = grid_out.shape[0], grid_out.shape[1]
    out_colors = ', '.join(Color.colors[c] for c in np.unique(grid_out.data))

    return f"""\
This image contains two grids from a visual IQ puzzle:

input grid (left):
- idx: {in_idx}
- size: {in_rows}x{in_cols}
- colors: {in_colors}

output grid (right):
- idx: {out_idx}
- size: {out_rows}x{out_cols}
- colors: {out_colors}\
"""

# %% ../nbs/02_describe.ipynb 31
class ShapeExtractor:
    """Extract shapes from grid pairs for analysis."""
    def __init__(self, task: ArcTask):
        # Store flattened list of grids in order: each training pair contributes
        # its grids consecutively, so pair k occupies indices 2k and 2k+1
        # (the same numbering `_pair_prompt` reports as `idx`).
        self.grids = [grid for pair in task.train for grid in pair]

    # NOTE(review): `_describe_indirect` hands this bound method to the chat as a
    # tool and schemas are built with `get_schema`; the docstring and parameter
    # comments below presumably end up in that schema - keep them stable.
    def extract_shapes(
        self,
        grid_idx: int,  # Index of the target grid
        color: str,  # Color of shapes to extract
        include_diagonal: bool,  # Consider diagonally adjacent cells as connected?
    ) -> list:  # List of extracted shapes (boolean arrays) and their positions
        """Extract contiguous regions of a specified color from a grid."""
        ORTH = np.array([[0,1,0], [1,1,1], [0,1,0]])  # 4-connectivity
        DIAG = np.ones((3,3))                         # 8-connectivity

        # Reject negative indices explicitly: plain list indexing would silently
        # wrap around and return a grid counted from the end.
        if not 0 <= grid_idx < len(self.grids):
            raise IndexError(f"Invalid grid_idx {grid_idx}. Must be between 0 and {len(self.grids)-1}")

        # Translate the color name to its cell value; report the valid names on failure.
        try:
            value = Color.colors.index(color)
        except ValueError as e:
            raise ValueError(f"Unknown color {color!r}. Must be one of: {', '.join(Color.colors)}") from e

        mask = self.grids[grid_idx].data == value
        structure = DIAG if include_diagonal else ORTH
        labeled, _ = ndimage.label(mask, structure=structure)

        # `find_objects` yields one bounding box per label, in label order (1, 2, ...).
        # Each shape is the boolean footprint of its label inside that box, paired
        # with the box's top-left (row, column) offset in the grid.
        shapes = []
        for label, bbox in enumerate(ndimage.find_objects(labeled), start=1):
            shapes.append((labeled[bbox] == label, (bbox[0].start, bbox[1].start)))
        return shapes

# %% ../nbs/02_describe.ipynb 33
@patch
@delegates(AsyncChat.__call__)
async def toolloop(self:AsyncChat,
             pr, # Prompt to pass to Claude
             max_steps=10, # Maximum number of tool requests to loop through
             trace_func:Optional[callable]=None, # Function to trace tool use steps (e.g `print`)
             cont_func:Optional[callable]=noop, # Function that stops loop if returns False
             **kwargs):
    "Send `pr`, then keep calling the chat while Claude stops with `tool_use`, for at most `max_steps` follow-ups"
    # Start of the history slice handed to `trace_func`.
    # NOTE(review): the `- 1` makes consecutive trace slices overlap by one message - confirm intended.
    trace_from = len(self.h) - 1
    resp = await self(pr, **kwargs)
    for _step in range(max_steps):
        if resp.stop_reason != 'tool_use':
            break
        if trace_func:
            trace_func(self.h[trace_from:])
            trace_from = len(self.h) - 1
        # Follow-up call without a new prompt; the same kwargs are re-sent.
        resp = await self(**kwargs)
        # `cont_func` sees the second-to-last history entry and may end the loop early.
        keep_going = (cont_func or noop)(self.h[-2])
        if not keep_going:
            break
    if trace_func:
        trace_func(self.h[trace_from:])
    return resp

# %% ../nbs/02_describe.ipynb 38
# Default system prompt for the merge step of the "indirect" strategy: used by
# `_describe_indirect` when no `sp_combine` is supplied and as the fallback for
# `DescriptionGenerator.merge_sp`. It asks for the analysis inside <reasoning> tags
# and the summary inside <description> tags; `Description.d` parses the
# `description` tag from the response.
# A trailing backslash inside the literal joins that line with the next one.
sp_merge = """\
You are an expert puzzle analyst tasked with deciphering complex visual transformation puzzles. \
Your goal is to infer the general rule that governs how an input grid is transformed into an output grid \
based on multiple descriptions of individual grid pairs.

The descriptions you'll analyze were generated by observers who each saw only one pair of grids. \
They had access to a shape extraction tool and were instructed to analyze shapes, patterns, and transformations in detail.

Analysis Steps:

1. PATTERN IDENTIFICATION
   For each transformation aspect, analyze across all descriptions:
   - Grid size relationships
   - Shape transformations (rotations, reflections, scaling)
   - Color role patterns
   - Spatial relationship preservation
   - Global vs local patterns

2. CONSISTENCY ANALYSIS
   For each observed pattern, rate:
   - Frequency: How many descriptions support it?
   - Consistency: Are there any contradictions?
   - Completeness: Does it explain all observations?
   - Simplicity: Is it the simplest explanation?

3. TRANSFORMATION CLASSIFICATION
   Determine if transformations are:
   - Local (applying to individual shapes)
   - Global (applying to entire grid)
   - Hierarchical (different rules at different scales)
   - Composite (multiple transformations applied sequentially)

4. RELATIONSHIP ANALYSIS
   Look for patterns in:
   - How shapes relate to each other (rotation, reflection, scaling)
   - How shapes interact with grid boundaries
   - How colors relate to shapes and patterns
   - How relative positions are maintained or changed

5. EDGE CASE CONSIDERATION
   Consider how the rule would handle:
   - Overlapping/touching shapes
   - Shapes at grid boundaries
   - Any other relevant edge cases

Show your analysis process in <reasoning> tags, including:
1. List of all unique characteristics across descriptions
2. Pattern frequency and consistency analysis
3. Alternative hypotheses considered
4. Evidence supporting your chosen rule
5. Potential edge cases and how they're handled

After your analysis, provide a concise solution summary in <description> tags. This summary should:
- Be no more than 4-5 sentences long
- Clearly describe first (in generality) the properties/objects in the input grids
- Then describe properties of the output grids and how an output grid is constructed from its input grid
- Avoid instance-specific descriptions or if-else statements
- Capture all key aspects of the transformation
- Use precise, unambiguous language, i.e. not use terms such as "follows a certain transformation", "according to some rule", etc.

If multiple interpretations are possible, choose the simplest one that explains all observations while respecting these constraints. \
Remember to close xml tags.
"""

# %% ../nbs/02_describe.ipynb 39
async def _describe_indirect(
    task: ArcTask | str,                        # Either an ArcTask object or a task ID string
    model: str = 'claude-3-5-sonnet-20241022',  # Model identifier (defaults to Sonnet 3.5)
    client_type: str = 'anthropic',             # 'anthropic', 'bedrock', or 'vertex'
    client_kwargs: Optional[Dict] = None,       # Optional kwargs for client instantiation
    sp: str | None = None,                      # Custom system prompt for individual analysis (if None, uses `sp_indiv`)
    sp_combine: str | None = None,              # Custom system prompt for synthesizing from independent descriptions (if None, uses `sp_merge`)
    temp: float = 0.0,                          # Sampling temperature for generation
    tools: Optional[list] = None,               # List of tools to make available to Claude (defaults to `[ShapeExtractor.extract_shapes]`)
    **kwargs                                    # Additional arguments passed to AsyncChat.__call__
) -> Description:                               # Container holding description and the list of chats used
    "Generate a description of an ARC task by analyzing examples independently and then combining insights."

    if isinstance(task, str): task = ArcTask(task)
    # `None` replaces the former shared `{}` default (mutable default argument).
    if client_kwargs is None: client_kwargs = {}

    # Use default prompts if none provided
    if sp is None: sp = sp_indiv
    if sp_combine is None: sp_combine = sp_merge
    # Default tool: shape extraction bound to this task's grids
    if tools is None:
        extractor = ShapeExtractor(task)
        tools = [extractor.extract_shapes]

    # Create a separate client and chat for each example pair
    pair_clients = [_create_client(client_type, client_kwargs) for _ in task.train]
    pair_chats = [_create_chat(model, c, sp, tools) for c in pair_clients]

    # Process examples concurrently. `kwargs` is forwarded so that caller options
    # reach the per-pair chats instead of being silently dropped.
    pair_tasks = [
        chat.toolloop([pair.plot(to_base64=True), _pair_prompt(pair, i)], temp=temp, **kwargs)
        for i, (chat, pair) in enumerate(zip(pair_chats, task.train))
    ]
    responses = await asyncio.gather(*pair_tasks)

    # Format and merge descriptions: each pair's final response text is wrapped
    # in a numbered <description> element and sent to a fresh chat.
    merge_chat = _create_chat(model, _create_client(client_type, client_kwargs), sp_combine)
    descs = '\n\n'.join(
        f'<description id="{i+1}">\n{r.content[0].text}\n</description>'
        for i, r in enumerate(responses)
    )
    merged = await merge_chat([descs], temp=temp, **kwargs)

    # Create description with all chats used
    all_chats = pair_chats + [merge_chat]
    return Description(
        content=merged.content[0].text,
        chats=all_chats,
        method='indirect'
    )

# %% ../nbs/02_describe.ipynb 42
class DescriptionGenerator:
    "Holds model, client and system-prompt settings for generating ARC task descriptions with Claude."
    def __init__(self, 
                 model: str = "claude-3-5-sonnet-20241022",  # Model identifier (defaults to Sonnet 3.5)
                 client_type: str = "anthropic",             # 'anthropic', 'bedrock', or 'vertex'
                 client_kwargs: Optional[Dict] = None,       # Optional kwargs for client instantiation
                 direct_sp: Optional[str] = None,            # Custom system prompt for direct description (if None, uses `sp_direct`)
                 indirect_sp: Optional[str] = None,          # Custom system prompt for single pair description (if None, uses `sp_indiv`)
                 merge_sp: Optional[str] = None):            # Custom system prompt for synthesized description (if None, uses `sp_merge`)
        self.model, self.client_type = model, client_type
        # Any falsy value (None, an empty dict, an empty string) selects the default.
        self.client_kwargs = client_kwargs if client_kwargs else {}
        self.direct_sp = direct_sp if direct_sp else sp_direct
        self.indirect_sp = indirect_sp if indirect_sp else sp_indiv
        self.merge_sp = merge_sp if merge_sp else sp_merge

    def _create_client(self) -> Union[AsyncAnthropic, AsyncAnthropicBedrock, AsyncAnthropicVertex]:
        "Build a client from the stored `client_type` and `client_kwargs` via the module-level `_create_client`."
        kind, options = self.client_type, self.client_kwargs
        return _create_client(kind, options)

    def _create_chat(self, sp: str, tools: Optional[list] = None) -> AsyncChat:
        "Build a chat for the stored model on a freshly created client via the module-level `_create_chat`."
        fresh_client = self._create_client()
        return _create_chat(self.model, fresh_client, sp, tools)

# %% ../nbs/02_describe.ipynb 45
@patch
async def describe_direct(
    self: DescriptionGenerator,
    task: ArcTask | str,           # ARC task or task ID to describe
    n: int = 1,                    # No. of descriptions to generate
    temp: float = 0.5,             # Temperature for generation (higher for diversity)
    prefill: str = '<reasoning>',  # Text to prefill the assistant's response with
    **kwargs                       # Additional arguments passed to AsyncChat.__call__
) -> List[Description]:            # List of `Description` objects
    "Run `_describe_direct` `n` times concurrently with this generator's settings and collect the results."
    jobs = []
    for _ in range(n):
        jobs.append(_describe_direct(
            task,
            model=self.model,
            client_type=self.client_type,
            client_kwargs=self.client_kwargs,
            sp=self.direct_sp,
            temp=temp,
            prefill=prefill,
            **kwargs,
        ))
    return await asyncio.gather(*jobs)

# %% ../nbs/02_describe.ipynb 54
@patch
async def describe_indirect(
    self: DescriptionGenerator,
    task: ArcTask | str,           # ARC task or task ID to describe
    n: int = 1,                    # No. of descriptions to generate
    temp: float = 0.6,             # Temperature for generation (higher for diversity)
    tools: Optional[list] = None,  # List of tools to make available to Claude (defaults to `[ShapeExtractor.extract_shapes]`)
    **kwargs                       # Additional arguments passed to AsyncChat.__call__
) -> List[Description]:            # List of `Description` objects
    "Generate n indirect descriptions of the task concurrently."
    jobs = []
    for _ in range(n):
        jobs.append(_describe_indirect(
            task,
            model=self.model,
            client_type=self.client_type,
            client_kwargs=self.client_kwargs,
            sp=self.indirect_sp,
            sp_combine=self.merge_sp,
            temp=temp,
            tools=tools,
            **kwargs,
        ))
    return await asyncio.gather(*jobs)

# %% ../nbs/02_describe.ipynb 65
@patch
async def describe_task(
    self: DescriptionGenerator,
    task: ArcTask | str,           # ARC task or task ID to describe
    n_direct: int = 1,             # No. of direct descriptions to generate
    n_indirect: int = 1,           # No. of indirect descriptions to generate
    temp: float = 0.7,             # Temperature for generation (higher for diversity)
    **kwargs
) -> List[Description]:            # List of `Description` objects
    "Run the direct and indirect strategies concurrently and return all descriptions (direct ones first)."
    # NOTE(review): `prefill=None` overrides `describe_direct`'s '<reasoning>' default - confirm intended.
    direct_job = self.describe_direct(task, n=n_direct, temp=temp, prefill=None, **kwargs)
    indirect_job = self.describe_indirect(task, n=n_indirect, temp=temp, tools=None, **kwargs)

    # Both strategies are awaited together; results come back in argument order.
    direct_descs, indirect_descs = await asyncio.gather(direct_job, indirect_job)

    return [*direct_descs, *indirect_descs]
